forked from mindspore-Ecosystem/mindspore
!7750 Add a simplification pattern to GraphKernel's ArithSimplify.
Merge pull request !7750 from DeshiChen/1026_simplify_mul
This commit is contained in:
commit
ac3a82006c
|
@ -258,6 +258,11 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
|
|||
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), new_lhs, new_rhs}, node);
|
||||
return new_cnode;
|
||||
};
|
||||
auto const_dup_lambda2 = [&node, &x, &const_1, &const_2]() -> AnfNodePtr {
|
||||
auto new_rhs = const_1.MulByPatternConst(const_2, x.GetNode(node));
|
||||
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_rhs}, node);
|
||||
return new_cnode;
|
||||
};
|
||||
auto exp_merge_lambda = [&node, &x, &y]() -> AnfNodePtr {
|
||||
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node);
|
||||
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node);
|
||||
|
@ -283,6 +288,8 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
|
|||
};
|
||||
// (x*C1)*(y*C2) ==> (x*y)*(C1*C2)
|
||||
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * (const_2 * y), const_dup_lambda);
|
||||
// (x*C1)*C2 ==> x*(C1*C2)
|
||||
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * const_2, const_dup_lambda2);
|
||||
// exp(x)*exp(y) ==> exp(x+y)
|
||||
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) * PUnaryOperation(prim::kPrimExp, y), exp_merge_lambda);
|
||||
// sqrt(x)*sqrt(x) ==> x
|
||||
|
|
Loading…
Reference in New Issue