forked from mindspore-Ecosystem/mindspore
!522 [bug]pass AdjustAllReduceMulAdd replace op is not has a infer
Merge pull request !522 from vlne-v1/I1F21H-pass-adjust_all_reduce_mul_add-repkace-op-is-not-has-a-infer
This commit is contained in:
commit
250dbb0e82
|
@ -248,17 +248,18 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
|
||||||
if (addn->size() != 2) {
|
if (addn->size() != 2) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
|
AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
|
||||||
if (x_ == nullptr || y_ == nullptr || z_ == nullptr) {
|
if (x_ == nullptr || y_ == nullptr || z_ == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto addn_op_node = addn->input(0);
|
||||||
|
auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
|
||||||
auto fg = node->func_graph();
|
auto fg = node->func_graph();
|
||||||
AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg);
|
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
|
||||||
AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg);
|
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
|
||||||
AnfNodePtr all_reduce = NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg);
|
AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
|
||||||
return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg);
|
return NewCNode({mul_, all_reduce, y_}, fg);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Visit(const AnfNodePtr &node) override {
|
void Visit(const AnfNodePtr &node) override {
|
||||||
|
@ -269,6 +270,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
|
||||||
AnfVisitor::Match(prim::kPrimMul)(node);
|
AnfVisitor::Match(prim::kPrimMul)(node);
|
||||||
level_ = 0;
|
level_ = 0;
|
||||||
if (is_reduce_match_) {
|
if (is_reduce_match_) {
|
||||||
|
mul_ = node->cast<CNodePtr>()->input(0);
|
||||||
y_ = tmp_;
|
y_ = tmp_;
|
||||||
} else {
|
} else {
|
||||||
z_ = node;
|
z_ = node;
|
||||||
|
@ -280,6 +282,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
|
||||||
if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
|
if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
if (cnode->size() > 1) {
|
if (cnode->size() > 1) {
|
||||||
|
all_reduce_ = cnode->input(0);
|
||||||
x_ = cnode->input(1);
|
x_ = cnode->input(1);
|
||||||
is_reduce_match_ = true;
|
is_reduce_match_ = true;
|
||||||
}
|
}
|
||||||
|
@ -302,6 +305,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
|
||||||
int level_{0};
|
int level_{0};
|
||||||
bool is_reduce_match_{false};
|
bool is_reduce_match_{false};
|
||||||
AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
|
AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
|
||||||
|
AnfNodePtr all_reduce_{nullptr}, mul_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
class ArithmeticSimplify {
|
class ArithmeticSimplify {
|
||||||
|
|
Loading…
Reference in New Issue