!7236 Add new pass:arithmetic_simplify and eliminate_empty_graph

Merge pull request !7236 from gengfei/1012_simplify_1.0
This commit is contained in:
mindspore-ci-bot 2020-10-15 14:52:15 +08:00 committed by Gitee
commit f4421b4504
6 changed files with 799 additions and 36 deletions

View File

@ -0,0 +1,556 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
#include <list>
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/pattern_matcher.h"
#include "frontend/operator/ops.h"
#include "utils/convert_utils.h"
#include "utils/utils.h"
namespace mindspore {
namespace opt {
AnfNodePtr NewCNodeWithInfo(const AnfNodePtrList &inputs, const AnfNodePtr &ori_node) {
auto func_graph = ori_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto new_cnode = func_graph->NewCNode(inputs);
new_cnode->set_abstract(ori_node->abstract());
new_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
ResetKernelInfo(new_cnode, AKG_KERNEL);
} else {
ResetKernelInfo(new_cnode, UNKNOWN_KERNEL_TYPE);
}
func_graph->AddNode(new_cnode);
return new_cnode;
}
AnfNodePtr SimplifyAdd(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimTensorAdd)) {
return nullptr;
}
PatternNode<AnfNodePtr> x, y, z;
PConstant<AnfNodePtr> zero_num(node, false, 0);
PConstant<AnfNodePtr> zero_scalar(node, false, 0, true);
PConstant<AnfNodePtr> any_const(node);
PConstant<AnfNodePtr> any_const_2(node);
auto add_distri_lambda = [&node, &x, &y, &any_const]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), node_tmp, any_const.GetNode(node)}, node);
return new_cnode;
};
auto add_union_lambda = [&node, &x, &any_const, &any_const_2]() -> AnfNodePtr {
auto new_rhs = any_const.AddByPatternConst(any_const_2, x.GetNode(node));
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), new_rhs}, node);
return new_cnode;
};
// A + 0 = A
MATCH_REPLACE(node, x + zero_num, x);
// A*C + B*C = (A + B)*C
MATCH_REPLACE_LAMBDA(node, (x * any_const) + (y * any_const), add_distri_lambda);
// (A + C1) + C2 = A + (C1 + C2)
MATCH_REPLACE_LAMBDA(node, (x + any_const) + any_const_2, add_union_lambda);
// A + (-A) = 0
MATCH_REPLACE(node, x + PUnaryOperation(prim::kPrimNeg, x), zero_scalar.NewValue());
return nullptr;
}
AnfNodePtr SimplifySub(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimSub)) {
return nullptr;
}
PatternNode<AnfNodePtr> x;
PConstant<AnfNodePtr> zero_num(node, false, 0);
PConstant<AnfNodePtr> any_const(node);
auto sub_toadd_lambda = [&node, &x, &any_const]() -> AnfNodePtr {
auto new_rhs = any_const.ValueNodeWithOprations(prim::kPrimNeg);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), new_rhs}, node);
return new_cnode;
};
// A - 0 = A
MATCH_REPLACE(node, x - zero_num, x);
// A - const = A + (-const)
MATCH_REPLACE_LAMBDA(node, x - any_const, sub_toadd_lambda);
return nullptr;
}
AnfNodePtr SimplifyNeg(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimNeg)) {
return nullptr;
}
PatternNode<AnfNodePtr> x;
MATCH_REPLACE(node, PUnaryOperation(prim::kPrimNeg, PUnaryOperation(prim::kPrimNeg, x)), x);
return nullptr;
}
AnfNodePtr SimplifyLog(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimLog)) {
return nullptr;
}
PatternNode<AnfNodePtr> x, y;
auto ln_front_lambda = [&node, &x, &y]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimAbs), x.GetNode(node)}, node);
auto node_tmp_2 = NewCNodeWithInfo({NewValueNode(prim::kPrimLog), node_tmp}, node);
auto new_cnode =
NewCNodeWithInfo({NewValueNode(prim::kPrimMul), y.GetNode(node), node_tmp_2}, node->cast<CNodePtr>()->input(1));
return new_cnode;
};
auto sqrt_ln_lambda = [&node, &x]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimLog), x.GetNode(node)}, node);
auto value = MakeValue(std::make_shared<FP32Imm>(0.5));
auto tensor_ptr = mindspore::ScalarToTensor(value->cast<ScalarPtr>());
auto value_node_ptr = MakeValueNode(std::make_shared<ValueNode>(tensor_ptr));
value_node_ptr->set_abstract(node->abstract());
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), value_node_ptr, node_tmp}, node);
return new_cnode;
};
auto rsqrt_ln_lambda = [&node, &x]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimLog), x.GetNode(node)}, node);
auto value = MakeValue(std::make_shared<FP32Imm>(-0.5));
auto tensor_ptr = mindspore::ScalarToTensor(value->cast<ScalarPtr>());
auto value_node_ptr = MakeValueNode(std::make_shared<ValueNode>(tensor_ptr));
value_node_ptr->set_abstract(node->abstract());
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), value_node_ptr, node_tmp}, node);
return new_cnode;
};
// Ln(Exp(A)) = A
MATCH_REPLACE(node, PUnaryOperation(prim::kPrimLog, PUnaryOperation(prim::kPrimExp, x)), x);
// Ln(Pow(A,B)) = B*Ln(Abs(A))
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimLog, PBinOperation(prim::kPrimPow, x, y, false)),
ln_front_lambda);
// Ln(Sqrt(A)) = 0.5*Ln(A)
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimLog, PUnaryOperation(prim::kPrimSqrt, x)), sqrt_ln_lambda);
// Ln(Rqrt(A)) = -0.5*Ln(A)
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimLog, PUnaryOperation(prim::kPrimRsqrt, x)), rsqrt_ln_lambda);
return nullptr;
}
AnfNodePtr SimplifyPow(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimPow)) {
return nullptr;
}
PatternNode<AnfNodePtr> x, y;
PConstant<AnfNodePtr> zero_num(node, false, 0);
PConstant<AnfNodePtr> one_const(node, false, 1);
PConstant<AnfNodePtr> two_const(node, false, 2);
PConstant<AnfNodePtr> negone_const(node, false, -1);
auto pow_zero_lambda = [&node]() -> AnfNodePtr {
auto value = MakeValue(std::make_shared<FP32Imm>(1));
auto tensor_ptr = mindspore::ScalarToTensor(value->cast<ScalarPtr>());
auto value_node_ptr = MakeValueNode(std::make_shared<ValueNode>(tensor_ptr));
value_node_ptr->set_abstract(node->abstract());
return value_node_ptr;
};
auto exp_power_lambda = [&node, &x, &y]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node->cast<CNodePtr>()->input(1));
return new_cnode;
};
auto squre_power_lambda = [&node, &x]() -> AnfNodePtr {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), x.GetNode(node)}, node);
return new_cnode;
};
auto r_power_lambda = [&node, &x]() -> AnfNodePtr {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimReciprocal), x.GetNode(node)}, node);
return new_cnode;
};
// Pow(A, 0) = 1
MATCH_REPLACE_LAMBDA(node, PBinOperation(prim::kPrimPow, x, zero_num, false), pow_zero_lambda);
// Pow(A, 1) = A
MATCH_REPLACE(node, PBinOperation(prim::kPrimPow, x, one_const, false), x);
// Pow(exp(A),B) = exp(A*B)
MATCH_REPLACE_LAMBDA(node, PBinOperation(prim::kPrimPow, PUnaryOperation(prim::kPrimExp, x), y, false),
exp_power_lambda);
// Pow(A, 2) = A*A
MATCH_REPLACE_LAMBDA(node, PBinOperation(prim::kPrimPow, x, two_const, false), squre_power_lambda);
// Pow(A, -1) = 1/A
MATCH_REPLACE_LAMBDA(node, PBinOperation(prim::kPrimPow, x, negone_const, false), r_power_lambda);
return nullptr;
}
AnfNodePtr SimplifySqrt(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimSqrt)) {
return nullptr;
}
PatternNode<AnfNodePtr> x, y;
auto mul_sqrt_lambda = [&node, &x]() -> AnfNodePtr {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimAbs), x.GetNode(node)}, node);
return new_cnode;
};
auto square_sqrt_lambda = [&node, &x]() -> AnfNodePtr {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimAbs), x.GetNode(node)}, node);
return new_cnode;
};
// pattern matcher cannot distinguish the same PatternNode in CaptureNode, so it needs to add judgment
// Sqrt(A*A) = |A|
MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimSqrt, PBinOperation(prim::kPrimMul, x, y)), mul_sqrt_lambda,
PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// Sqrt(Square(A)) = |A|
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimSqrt, PUnaryOperation(prim::kPrimSquare, x)),
square_sqrt_lambda);
return nullptr;
}
AnfNodePtr SimplifyRsqrt(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimRsqrt)) {
return nullptr;
}
PatternNode<AnfNodePtr> x;
PConstant<AnfNodePtr> num_one(node, false, 1, true);
PConstant<AnfNodePtr> num_negtwo(node, false, -2, true);
auto power_rsqrt_lambda = [&node, &x]() -> AnfNodePtr {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimAbs), x.GetNode(node)}, node);
return new_cnode;
};
auto div_rsqrt_lambda = [&node, &x]() -> AnfNodePtr {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), x.GetNode(node)}, node);
return new_cnode;
};
// Rsqrt(Pow(A, -2)) = |A|
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimRsqrt, PBinOperation(prim::kPrimPow, x, num_negtwo, false)),
power_rsqrt_lambda);
// Rsqrt(Divide(1, A)) = Sqrt(A)
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimRsqrt, PBinOperation(prim::kPrimRealDiv, num_one, x, false)),
div_rsqrt_lambda);
return nullptr;
}
AnfNodePtr SimplifySelect(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimSelect)) {
return nullptr;
}
PatternNode<AnfNodePtr> x, y, z;
// select(x,y,y) = y
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, x, y, z), y,
PIsEqual<AnfNodePtr>()(y.GetNode(node), z.GetNode(node)));
return nullptr;
}
AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimMul)) {
return nullptr;
}
PatternNode<AnfNodePtr> x, y;
PConstant<AnfNodePtr> const_1(node), const_2(node);
auto const_dup_lambda = [&node, &x, &y, &const_1, &const_2]() -> AnfNodePtr {
auto new_lhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), y.GetNode(node)}, node);
auto new_rhs = const_1.MulByPatternConst(const_2, x.GetNode(node));
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), new_lhs, new_rhs}, node);
return new_cnode;
};
auto exp_merge_lambda = [&node, &x, &y]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node);
return new_cnode;
};
auto sqrt_merge_lambda = [&node, &x, &y]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), node_tmp}, node);
return new_cnode;
};
auto rsqrt_merge_lambda = [&node, &x]() -> AnfNodePtr {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimReciprocal), x.GetNode(node)}, node);
return new_cnode;
};
auto rsqrt_merge_lambda_2 = [&node, &x, &y]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRsqrt), node_tmp}, node);
return new_cnode;
};
auto rsqrt_merge_lambda_3 = [&node, &x]() -> AnfNodePtr {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), x.GetNode(node)}, node);
return new_cnode;
};
// (x*C1)*(y*C2) ==> (x*y)*(C1*C2)
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * (const_2 * y), const_dup_lambda);
// exp(x)*exp(y) ==> exp(x+y)
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) * PUnaryOperation(prim::kPrimExp, y), exp_merge_lambda);
// sqrt(x)*sqrt(x) ==> x
MATCH_REPLACE_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y), x,
PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// sqrt(x)*sqrt(y) ==> sqrt(x*y)
MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y),
sqrt_merge_lambda, !PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// rsqrt(x)*rsqrt(x) ==> 1/x
MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimRsqrt, x) * PUnaryOperation(prim::kPrimRsqrt, y),
rsqrt_merge_lambda, PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// rsqrt(x)*rsqrt(y) ==> rsqrt(x*y)
MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimRsqrt, x) * PUnaryOperation(prim::kPrimRsqrt, y),
rsqrt_merge_lambda_2, !PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// x*rsqrt(x) ==> sqrt(x)
MATCH_REPLACE_LAMBDA_IF(node, x * PUnaryOperation(prim::kPrimRsqrt, y), rsqrt_merge_lambda_3,
PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
return nullptr;
}
AnfNodePtr SimplifyDiv(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimRealDiv)) {
return nullptr;
}
PatternNode<AnfNodePtr> x, y, u, v;
PConstant<AnfNodePtr> const_1(node), const_2(node);
PConstant<AnfNodePtr> const_one(node, false, 1);
PConstant<AnfNodePtr> const_one_scalar(node, false, 1, true);
auto div_exp_lambda_1 = [&node, &x, &y]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimSub), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node);
return new_cnode;
};
auto div_exp_lambda_2 = [&node, &x, &y]() -> AnfNodePtr {
auto node_neg = NewCNodeWithInfo({NewValueNode(prim::kPrimNeg), y.GetNode(node)}, node);
auto node_exp = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_neg}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), node_exp}, node);
return new_cnode;
};
auto div_pow_const = [&node, &x, &y, &const_1]() -> AnfNodePtr {
auto new_const = const_1.ValueNodeWithOprations(prim::kPrimNeg);
auto new_rhs = NewCNodeWithInfo({NewValueNode(prim::kPrimPow), y.GetNode(node), new_const}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_rhs}, node);
return new_cnode;
};
auto div_sqrt_lambda_1 = [&node, &x]() -> AnfNodePtr {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), x.GetNode(node)}, node);
return new_cnode;
};
auto div_sqrt_lambda_2 = [&node, &x, &y]() -> AnfNodePtr {
auto node_rsqrt = NewCNodeWithInfo({NewValueNode(prim::kPrimRsqrt), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), node_rsqrt}, node);
return new_cnode;
};
auto div_const = [&node, &x, &const_1]() -> AnfNodePtr {
auto new_const = const_1.ValueNodeWithOprations(prim::kPrimReciprocal);
if (new_const == nullptr) {
return nullptr;
}
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_const}, node);
return new_cnode;
};
auto div_rsqrt_lambda = [&node, &x, &y]() -> AnfNodePtr {
auto node_rsqrt = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), node_rsqrt}, node);
return new_cnode;
};
auto div_lambda_1 = [&node, &x, &y, &u, &v]() -> AnfNodePtr {
auto new_lhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), v.GetNode(node)}, node);
auto new_rhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), y.GetNode(node), u.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), new_lhs, new_rhs}, node);
return new_cnode;
};
auto div_lambda_2 = [&node, &x, &y, &u]() -> AnfNodePtr {
auto new_rhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), y.GetNode(node), u.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), x.GetNode(node), new_rhs}, node);
return new_cnode;
};
auto div_lambda_3 = [&node, &x, &u, &v]() -> AnfNodePtr {
auto new_lhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), v.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), new_lhs, u.GetNode(node)}, node);
return new_cnode;
};
// x/1 ==> x
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarDiv, x, const_one_scalar, false), x);
MATCH_REPLACE(node, x / const_one, x);
// e^x/e^y ==> e^(x-y)
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_1);
// x / e^y ==> x * e^(-y)
MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_2);
// x / y^const ==> x * y^(-const)
MATCH_REPLACE_LAMBDA(node, x / PBinOperation(prim::kPrimPow, y, const_1), div_pow_const);
// x / sqrt(x) ==> sqrt(x)
MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_1,
PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// x / sqrt(y) ==> x * rsqrt(y)
MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_2,
!PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// x / rsqrt(y) ==> x * sqrt(y)
MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimRsqrt, y), div_rsqrt_lambda);
// // x / const ==> x * (1/const)
MATCH_REPLACE_LAMBDA(node, x / const_1, div_const);
// (x/y) / (u/v) ==> (x*v) / (y*u)
MATCH_REPLACE_LAMBDA(node, (x / y) / (u / v), div_lambda_1);
// (x/y) / u ==> x / (y*u)
MATCH_REPLACE_LAMBDA(node, (x / y) / u, div_lambda_2);
// x / (u/v) ==> (x*v) / u
MATCH_REPLACE_LAMBDA(node, x / (u / v), div_lambda_3);
return nullptr;
}
#define PERFORM_REPLACE(OldNode, NewNode, Graph, FLAG) \
if ((NewNode) != nullptr) { \
(Graph)->manager()->Replace((OldNode), (NewNode)); \
(FLAG) = true; \
}
AnfNodePtr TrySimplify(const AnfNodePtr &node) {
std::list<std::function<AnfNodePtr(AnfNodePtr)>> SimplifyFuncList = {
SimplifyAdd, SimplifyDiv, SimplifyLog, SimplifyMul, SimplifyNeg,
SimplifyPow, SimplifyRsqrt, SimplifySelect, SimplifySqrt, SimplifySub};
for (auto f : SimplifyFuncList) {
auto ret = f(node);
if (ret != nullptr) {
return ret;
}
}
return nullptr;
}
void InlineSubgraph(const CNodePtr &kernel_node, const FuncGraphPtr &sub_graph, const FuncGraphPtr &main_func_graph) {
AnfNodePtrList ins;
ins.insert(ins.end(), kernel_node->inputs().begin() + 1, kernel_node->inputs().end());
auto out = InlineClone(sub_graph, main_func_graph, ins, kernel_node->input(0)->scope());
auto mng = main_func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
mng->Replace(kernel_node, out);
}
CNodePtr AddIdentityToEmptyPath(const AnfNodePtr &node, const FuncGraphPtr &sub_graph) {
if (node->isa<Parameter>() || node->isa<ValueNode>()) {
auto identity_node = sub_graph->NewCNode({NewValueNode(prim::kPrimIdentity), node});
identity_node->set_abstract(node->abstract());
sub_graph->AddNode(identity_node);
return identity_node;
}
return nullptr;
}
// If the return of the subgraph contains input Parameters or a new ValueNode,
// add identity mapping to them to avoid dealing with empty paths in subgraphs,
// then inline the subgraph into the main graph
bool CheckAndInlineEmptyGraph(const AnfNodePtr &node, const FuncGraphPtr &main_func_graph) {
if (!AnfAlgo::IsGraphKernel(node)) {
MS_LOG(ERROR) << node->ToString() << "is not a graph kernel\n";
return false;
}
auto kernel_node = node->cast<CNodePtr>();
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(kernel_node);
auto mng = sub_graph->manager();
if (mng == nullptr) {
mng = Manage(sub_graph, false);
sub_graph->set_manager(mng);
}
auto sub_return = sub_graph->get_return();
auto pred_node_of_return = sub_return->input(1);
bool ret = false;
if (!IsPrimitiveCNode(pred_node_of_return, prim::kPrimMakeTuple)) { // Single output
auto new_cnode = AddIdentityToEmptyPath(pred_node_of_return, sub_graph);
if (new_cnode != nullptr) {
sub_return->set_input(1, new_cnode);
ret = true;
}
} else { // Multiple output
auto maketuple_node = pred_node_of_return->cast<CNodePtr>();
size_t size_ret = maketuple_node->inputs().size();
size_t empty_path_cnt = 0;
for (size_t i = 1; i < size_ret; i++) {
auto tmp_node = maketuple_node->input(i);
auto new_cnode = AddIdentityToEmptyPath(tmp_node, sub_graph);
if (new_cnode != nullptr) {
maketuple_node->set_input(i, new_cnode);
empty_path_cnt++;
}
}
if (empty_path_cnt == 0) { // normal subgraph
return false;
} else if (empty_path_cnt < size_ret - 1) {
MS_EXCEPTION(NotSupportError);
return false;
} else { // empty subgraph
ret = true;
}
}
if (ret) {
InlineSubgraph(kernel_node, sub_graph, main_func_graph);
}
return ret;
}
AnfNodePtr MatchIdentity(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimIdentity)) {
return nullptr;
}
PatternNode<AnfNodePtr> x;
MATCH_REPLACE(node, PUnaryOperation(prim::kPrimIdentity, x), x);
return nullptr;
}
void EliminateEmptyGraph(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, false);
func_graph->set_manager(mng);
}
bool empty_graph = false;
auto cnodes = func_graph->GetOrderedCnodes();
for (auto cnode : cnodes) {
if (AnfAlgo::IsGraphKernel(cnode)) {
empty_graph = empty_graph || CheckAndInlineEmptyGraph(cnode, func_graph);
}
}
if (empty_graph) {
cnodes = func_graph->GetOrderedCnodes();
for (auto cnode : cnodes) {
auto node = cnode->cast<AnfNodePtr>();
auto new_node = MatchIdentity(node);
if (new_node != nullptr) {
mng->Replace(node, new_node);
}
}
}
}
bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
bool replaced = false;
for (auto node : func_graph->GetOrderedCnodes()) {
if (AnfAlgo::IsGraphKernel(node)) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto mng_sub = sub_graph->manager();
if (mng_sub == nullptr) {
mng_sub = Manage(sub_graph, false);
sub_graph->set_manager(mng_sub);
}
bool sub_graph_changed = false;
for (auto node_sub : sub_graph->GetOrderedCnodes()) {
auto new_node = TrySimplify(node_sub);
if (new_node != nullptr) {
sub_graph_changed = true;
PERFORM_REPLACE(node_sub->cast<AnfNodePtr>(), new_node, sub_graph, replaced);
}
}
if (sub_graph_changed) {
ResetKernelInfo(node, AKG_KERNEL);
}
} else {
auto new_node = TrySimplify(node);
PERFORM_REPLACE(node->cast<AnfNodePtr>(), new_node, func_graph, replaced);
}
}
EliminateEmptyGraph(func_graph);
return replaced;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ARITHMETIC_SIMPLIFY_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ARITHMETIC_SIMPLIFY_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace opt {
class ArithmeticSimplify : public Pass {
public:
ArithmeticSimplify() : Pass("arithmetic_simplify") {}
~ArithmeticSimplify() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
using ArithmeticSimplifyPtr = std::shared_ptr<ArithmeticSimplify>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ARITHMETIC_SIMPLIFY_H_

View File

@ -713,26 +713,6 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
std::vector<int> need_inline_;
};
// Eliminate the redundant MakeTuple-GetItem operations.
void EliminateTupleGetItem(const FuncGraphPtr &func_graph) {
auto callback = [](const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) return;
for (size_t i = 1; i < cnode->size(); ++i) {
auto getitem = cnode->input(i);
if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue;
auto getitem_cnode = getitem->cast<CNodePtr>();
auto maketuple = getitem_cnode->input(kRealInputNodeIndexInTupleGetItem);
if (!AnfAlgo::CheckPrimitiveType(maketuple, prim::kPrimMakeTuple)) continue;
auto maketuple_cnode = maketuple->cast<CNodePtr>();
int getitem_idx =
GetValue<int>(getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>()->value());
cnode->set_input(i, maketuple_cnode->input(getitem_idx + 1));
}
};
TraverseFuncGraph(func_graph, callback);
}
bool TrySplit(const CNodePtr &sub_root_cnode) {
MS_LOG(INFO) << "Split process node: " << sub_root_cnode->fullname_with_scope();
auto splitter = Splitter::MakeSplitter(sub_root_cnode, std::make_shared<CostModelSplitSchemer>());
@ -761,9 +741,6 @@ bool GraphKernelSplitter::Run(const FuncGraphPtr &func_graph) {
changed = TrySplit(node) || changed;
}
}
if (changed) {
EliminateTupleGetItem(func_graph);
}
mng->RemoveRoots();
mng->KeepRoots({func_graph});
return changed;

View File

@ -43,6 +43,7 @@
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/ms_utils.h"
#include "utils/config_manager.h"
@ -116,7 +117,11 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
// After Simplify and Splitter, a lot of redundant getitem/maketuple
// will be exposed, use GetitemTuple Pass to delete them.
pm->AddPass(std::make_shared<opt::GetitemTuple>());
pm->AddPass(std::make_shared<opt::BindValueToGraph>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);

View File

@ -152,6 +152,49 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > {
mutable AnfNodePtr captured_binop_node_{nullptr};
};
template <typename T>
class PUnaryOperation : public PBase<PUnaryOperation<T> > {
public:
PUnaryOperation(const PrimitivePtr &prim, const T &x) : prim_(prim), x_(x) {}
~PUnaryOperation() = default;
AnfNodePtr GetNode(const AnfNodePtr &node) const {
AnfNodePtrList list = {NewValueNode(prim_), x_.GetNode(node)};
return NewCNode(list, node->func_graph());
}
bool TryCapture_(const AnfNodePtr &node) const {
if (IsPrimitiveCNode(node, prim_)) {
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
if (inputs.size() == 2 && x_.TryCapture(inputs[1])) {
captured_unaryop_node_ = node;
return true;
}
}
return false;
}
AnfNodePtr GetOriginalNode() const {
if (captured_unaryop_node_ == nullptr) {
MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it.";
}
return captured_unaryop_node_;
}
void Reset() const {
x_.Reset();
captured_unaryop_node_ = nullptr;
}
using Internal = const PUnaryOperation<T> &;
private:
const PrimitivePtr prim_;
typename T::Internal x_;
mutable AnfNodePtr captured_unaryop_node_{nullptr};
};
///
/// Helper functions to apply a pattern function on all elements of a tuple
///
@ -681,10 +724,74 @@ class PConstant : public PBase<PConstant<T> > {
return new_vnode;
}
// Support function to multiply two constant tensors: partially support broadcasting shapes
template <typename TD>
TD CalcuConstant(const TD &data, const PrimitivePtr &calcu_type) {
TD tmp_data = data;
if (calcu_type == prim::kPrimReciprocal) {
if (data == 0) {
MS_EXCEPTION(ValueError);
} else {
tmp_data = 1 / data;
}
}
if (calcu_type == prim::kPrimNeg) {
tmp_data = -data;
}
return tmp_data;
}
// calculate const with different operations
AnfNodePtr ValueNodeWithOprations(const PrimitivePtr &calcu_type) {
AnfNodePtr node = this->GetNode(captured_node_);
if (!node->isa<ValueNode>()) {
MS_EXCEPTION(ValueError) << "CalcuValue is trying to use a not ValueNode.";
}
auto value = node->cast<ValueNodePtr>()->value();
if (value->isa<tensor::Tensor>()) {
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat) ||
(tensor_type == TypeId::kNumberTypeFloat64)) {
float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
return nullptr;
}
data2[i] = CalcuConstant(data2[i], calcu_type);
}
}
if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c());
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
return nullptr;
}
data2[i] = CalcuConstant(data2[i], calcu_type);
}
}
if (tensor_type == TypeId::kNumberTypeFloat64) {
double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
return nullptr;
}
data2[i] = CalcuConstant(data2[i], calcu_type);
}
}
return node;
}
return nullptr;
}
enum BinOperator {
ADD = 0,
MULTIPLY,
};
// Support function to add/multiply two constant tensors: partially support broadcasting shapes
template <typename TM>
void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
int out_data_size) const {
void CalcByOperator(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
int out_data_size, BinOperator bin_operator) const {
TM *data_1 = reinterpret_cast<TM *>(in_data_1);
TM *data_2 = reinterpret_cast<TM *>(in_data_2);
TM *data_out = new TM[out_data_size];
@ -700,27 +807,42 @@ class PConstant : public PBase<PConstant<T> > {
}
if (in_data_2_size == 1) {
for (int i = 0; i < out_data_size; i++) {
data_out[i] *= data_2[0];
if (bin_operator == ADD) {
data_out[i] += data_2[0];
} else {
data_out[i] *= data_2[0];
}
}
} else {
if (in_data_2_size < out_data_size) {
MS_EXCEPTION(ValueError) << "in_data_2_size is smaller than out_data_size.";
}
for (int i = 0; i < out_data_size; i++) {
data_out[i] *= data_2[i];
if (bin_operator == ADD) {
data_out[i] += data_2[i];
} else {
data_out[i] *= data_2[i];
}
}
}
*out_data = reinterpret_cast<void *>(data_out);
return;
}
AnfNodePtr AddByPatternConst(const PConstant<T> &vpnode_2, const AnfNodePtr &node_3) const {
AnfNodePtr vnode_1 = this->GetNode(captured_node_);
AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_);
return CalcConstantTensors(vnode_1, vnode_2, node_3, ADD);
}
AnfNodePtr MulByPatternConst(const PConstant<T> &vpnode_2, const AnfNodePtr &node_3) const {
AnfNodePtr vnode_1 = this->GetNode(captured_node_);
AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_);
return MulConstantTensors(vnode_1, vnode_2, node_3);
return CalcConstantTensors(vnode_1, vnode_2, node_3, MULTIPLY);
}
AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) const {
AnfNodePtr CalcConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3,
BinOperator bin_operator) const {
if (!vnode_1->isa<ValueNode>() || !vnode_2->isa<ValueNode>() || (vnode_1->abstract() == nullptr) ||
(vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) {
return nullptr;
@ -778,21 +900,21 @@ class PConstant : public PBase<PConstant<T> > {
void *data_out = nullptr;
if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat32) ||
(new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat)) {
Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size);
CalcByOperator<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator);
ret = memcpy_s(data, mem_size, data_out, mem_size);
delete[] reinterpret_cast<float *>(data_out);
} else {
if (new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat64) {
Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size);
CalcByOperator<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator);
ret = memcpy_s(data, mem_size, data_out, mem_size);
delete[] reinterpret_cast<double *>(data_out);
} else {
if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeInt32) ||
(new_tensor_ptr->data_type() == TypeId::kNumberTypeInt)) {
Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size);
CalcByOperator<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator);
ret = memcpy_s(data, mem_size, data_out, mem_size);
delete[] reinterpret_cast<int *>(data_out);
} else {
@ -833,6 +955,8 @@ class PConstant : public PBase<PConstant<T> > {
// Arithmetic operations
BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd, true);
BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true);
BIN_OPERATION_PATTERN(operator/, prim::kPrimRealDiv, false);
BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false);
// Macros for match and replace
#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \

View File

@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
class Net(Cell):
def __init__(self):
super(Net, self).__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
self.sqrt = P.Sqrt()
self.pow = P.Pow()
self.neg = P.Neg()
def construct(self, x, y):
add_res1 = self.add(x, 4)
add_res2 = self.add(add_res1, 5)
sub_res = self.sub(y, 3)
mul_res = self.mul(self.sqrt(add_res2), self.sqrt(sub_res))
div_res = self.div(mul_res, self.sqrt(mul_res))
pow_res = self.pow(y, 2)
neg_res = self.neg(self.neg(pow_res))
return self.add(div_res, neg_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic():
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
input_y = np.abs(input_y) + 3
add_res = input_x + 9
sub_res = input_y + (-3)
mul_res = np.sqrt(add_res * sub_res)
div_res = np.sqrt(mul_res)
pow_res = input_y * input_y
neg_res = pow_res
expect = div_res + neg_res
net = Net()
result = net(Tensor(input_x), Tensor(input_y))
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res