!7750 Add a simplification pattern to GraphKernel's ArithSimplify.

Merge pull request !7750 from DeshiChen/1026_simplify_mul
This commit is contained in:
mindspore-ci-bot 2020-10-27 15:59:13 +08:00 committed by Gitee
commit ac3a82006c
1 changed files with 7 additions and 0 deletions

View File

@ -258,6 +258,11 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), new_lhs, new_rhs}, node);
return new_cnode;
};
auto const_dup_lambda2 = [&node, &x, &const_1, &const_2]() -> AnfNodePtr {
auto new_rhs = const_1.MulByPatternConst(const_2, x.GetNode(node));
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_rhs}, node);
return new_cnode;
};
auto exp_merge_lambda = [&node, &x, &y]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node);
@ -283,6 +288,8 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
};
// (x*C1)*(y*C2) ==> (x*y)*(C1*C2)
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * (const_2 * y), const_dup_lambda);
// (x*C1)*C2 ==> x*(C1*C2)
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * const_2, const_dup_lambda2);
// exp(x)*exp(y) ==> exp(x+y)
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) * PUnaryOperation(prim::kPrimExp, y), exp_merge_lambda);
// sqrt(x)*sqrt(x) ==> x