diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc new file mode 100644 index 00000000000..58e1301c52b --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc @@ -0,0 +1,556 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/graph_kernel/arithmetic_simplify.h" +#include +#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/pattern_matcher.h" +#include "frontend/operator/ops.h" +#include "utils/convert_utils.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { + +AnfNodePtr NewCNodeWithInfo(const AnfNodePtrList &inputs, const AnfNodePtr &ori_node) { + auto func_graph = ori_node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + auto new_cnode = func_graph->NewCNode(inputs); + new_cnode->set_abstract(ori_node->abstract()); + new_cnode->set_kernel_info(std::make_shared()); + if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + ResetKernelInfo(new_cnode, AKG_KERNEL); + } else { + ResetKernelInfo(new_cnode, UNKNOWN_KERNEL_TYPE); + } + func_graph->AddNode(new_cnode); + return new_cnode; +} + +AnfNodePtr SimplifyAdd(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, 
prim::kPrimTensorAdd)) { + return nullptr; + } + PatternNode x, y, z; + PConstant zero_num(node, false, 0); + PConstant zero_scalar(node, false, 0, true); + PConstant any_const(node); + PConstant any_const_2(node); + + auto add_distri_lambda = [&node, &x, &y, &any_const]() -> AnfNodePtr { + auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), node_tmp, any_const.GetNode(node)}, node); + return new_cnode; + }; + auto add_union_lambda = [&node, &x, &any_const, &any_const_2]() -> AnfNodePtr { + auto new_rhs = any_const.AddByPatternConst(any_const_2, x.GetNode(node)); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), new_rhs}, node); + return new_cnode; + }; + // A + 0 = A + MATCH_REPLACE(node, x + zero_num, x); + // A*C + B*C = (A + B)*C + MATCH_REPLACE_LAMBDA(node, (x * any_const) + (y * any_const), add_distri_lambda); + // (A + C1) + C2 = A + (C1 + C2) + MATCH_REPLACE_LAMBDA(node, (x + any_const) + any_const_2, add_union_lambda); + // A + (-A) = 0 + MATCH_REPLACE(node, x + PUnaryOperation(prim::kPrimNeg, x), zero_scalar.NewValue()); + return nullptr; +} + +AnfNodePtr SimplifySub(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimSub)) { + return nullptr; + } + PatternNode x; + PConstant zero_num(node, false, 0); + PConstant any_const(node); + auto sub_toadd_lambda = [&node, &x, &any_const]() -> AnfNodePtr { + auto new_rhs = any_const.ValueNodeWithOprations(prim::kPrimNeg); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), new_rhs}, node); + return new_cnode; + }; + // A - 0 = A + MATCH_REPLACE(node, x - zero_num, x); + // A - const = A + (-const) + MATCH_REPLACE_LAMBDA(node, x - any_const, sub_toadd_lambda); + return nullptr; +} + +AnfNodePtr SimplifyNeg(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimNeg)) { + return nullptr; + 
} + PatternNode x; + MATCH_REPLACE(node, PUnaryOperation(prim::kPrimNeg, PUnaryOperation(prim::kPrimNeg, x)), x); + return nullptr; +} + +AnfNodePtr SimplifyLog(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimLog)) { + return nullptr; + } + PatternNode x, y; + auto ln_front_lambda = [&node, &x, &y]() -> AnfNodePtr { + auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimAbs), x.GetNode(node)}, node); + auto node_tmp_2 = NewCNodeWithInfo({NewValueNode(prim::kPrimLog), node_tmp}, node); + auto new_cnode = + NewCNodeWithInfo({NewValueNode(prim::kPrimMul), y.GetNode(node), node_tmp_2}, node->cast()->input(1)); + return new_cnode; + }; + auto sqrt_ln_lambda = [&node, &x]() -> AnfNodePtr { + auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimLog), x.GetNode(node)}, node); + auto value = MakeValue(std::make_shared(0.5)); + auto tensor_ptr = mindspore::ScalarToTensor(value->cast()); + auto value_node_ptr = MakeValueNode(std::make_shared(tensor_ptr)); + value_node_ptr->set_abstract(node->abstract()); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), value_node_ptr, node_tmp}, node); + return new_cnode; + }; + auto rsqrt_ln_lambda = [&node, &x]() -> AnfNodePtr { + auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimLog), x.GetNode(node)}, node); + auto value = MakeValue(std::make_shared(-0.5)); + auto tensor_ptr = mindspore::ScalarToTensor(value->cast()); + auto value_node_ptr = MakeValueNode(std::make_shared(tensor_ptr)); + value_node_ptr->set_abstract(node->abstract()); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), value_node_ptr, node_tmp}, node); + return new_cnode; + }; + // Ln(Exp(A)) = A + MATCH_REPLACE(node, PUnaryOperation(prim::kPrimLog, PUnaryOperation(prim::kPrimExp, x)), x); + // Ln(Pow(A,B)) = B*Ln(Abs(A)) + MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimLog, PBinOperation(prim::kPrimPow, x, y, false)), + ln_front_lambda); + // Ln(Sqrt(A)) = 0.5*Ln(A) + MATCH_REPLACE_LAMBDA(node, 
PUnaryOperation(prim::kPrimLog, PUnaryOperation(prim::kPrimSqrt, x)), sqrt_ln_lambda); + // Ln(Rqrt(A)) = -0.5*Ln(A) + MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimLog, PUnaryOperation(prim::kPrimRsqrt, x)), rsqrt_ln_lambda); + return nullptr; +} + +AnfNodePtr SimplifyPow(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimPow)) { + return nullptr; + } + PatternNode x, y; + PConstant zero_num(node, false, 0); + PConstant one_const(node, false, 1); + PConstant two_const(node, false, 2); + PConstant negone_const(node, false, -1); + auto pow_zero_lambda = [&node]() -> AnfNodePtr { + auto value = MakeValue(std::make_shared(1)); + auto tensor_ptr = mindspore::ScalarToTensor(value->cast()); + auto value_node_ptr = MakeValueNode(std::make_shared(tensor_ptr)); + value_node_ptr->set_abstract(node->abstract()); + return value_node_ptr; + }; + auto exp_power_lambda = [&node, &x, &y]() -> AnfNodePtr { + auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), y.GetNode(node)}, node); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node->cast()->input(1)); + return new_cnode; + }; + auto squre_power_lambda = [&node, &x]() -> AnfNodePtr { + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), x.GetNode(node)}, node); + return new_cnode; + }; + auto r_power_lambda = [&node, &x]() -> AnfNodePtr { + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimReciprocal), x.GetNode(node)}, node); + return new_cnode; + }; + // Pow(A, 0) = 1 + MATCH_REPLACE_LAMBDA(node, PBinOperation(prim::kPrimPow, x, zero_num, false), pow_zero_lambda); + // Pow(A, 1) = A + MATCH_REPLACE(node, PBinOperation(prim::kPrimPow, x, one_const, false), x); + // Pow(exp(A),B) = exp(A*B) + MATCH_REPLACE_LAMBDA(node, PBinOperation(prim::kPrimPow, PUnaryOperation(prim::kPrimExp, x), y, false), + exp_power_lambda); + // Pow(A, 2) = A*A + MATCH_REPLACE_LAMBDA(node, PBinOperation(prim::kPrimPow, x, two_const, 
false), squre_power_lambda); + // Pow(A, -1) = 1/A + MATCH_REPLACE_LAMBDA(node, PBinOperation(prim::kPrimPow, x, negone_const, false), r_power_lambda); + return nullptr; +} + +AnfNodePtr SimplifySqrt(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimSqrt)) { + return nullptr; + } + PatternNode x, y; + auto mul_sqrt_lambda = [&node, &x]() -> AnfNodePtr { + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimAbs), x.GetNode(node)}, node); + return new_cnode; + }; + auto square_sqrt_lambda = [&node, &x]() -> AnfNodePtr { + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimAbs), x.GetNode(node)}, node); + return new_cnode; + }; + // pattern matcher cannot distinguish the same PatternNode in CaptureNode, so it needs to add judgment + // Sqrt(A*A) = |A| + MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimSqrt, PBinOperation(prim::kPrimMul, x, y)), mul_sqrt_lambda, + PIsEqual()(x.GetNode(node), y.GetNode(node))); + // Sqrt(Square(A)) = |A| + MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimSqrt, PUnaryOperation(prim::kPrimSquare, x)), + square_sqrt_lambda); + return nullptr; +} + +AnfNodePtr SimplifyRsqrt(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimRsqrt)) { + return nullptr; + } + PatternNode x; + PConstant num_one(node, false, 1, true); + PConstant num_negtwo(node, false, -2, true); + auto power_rsqrt_lambda = [&node, &x]() -> AnfNodePtr { + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimAbs), x.GetNode(node)}, node); + return new_cnode; + }; + auto div_rsqrt_lambda = [&node, &x]() -> AnfNodePtr { + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), x.GetNode(node)}, node); + return new_cnode; + }; + // Rsqrt(Pow(A, -2)) = |A| + MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimRsqrt, PBinOperation(prim::kPrimPow, x, num_negtwo, false)), + power_rsqrt_lambda); + // Rsqrt(Divide(1, A)) = Sqrt(A) + MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimRsqrt, 
PBinOperation(prim::kPrimRealDiv, num_one, x, false)), + div_rsqrt_lambda); + return nullptr; +} + +AnfNodePtr SimplifySelect(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimSelect)) { + return nullptr; + } + PatternNode x, y, z; + // select(x,y,y) = y + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, x, y, z), y, + PIsEqual()(y.GetNode(node), z.GetNode(node))); + return nullptr; +} + +AnfNodePtr SimplifyMul(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimMul)) { + return nullptr; + } + PatternNode x, y; + PConstant const_1(node), const_2(node); + + auto const_dup_lambda = [&node, &x, &y, &const_1, &const_2]() -> AnfNodePtr { + auto new_lhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), y.GetNode(node)}, node); + auto new_rhs = const_1.MulByPatternConst(const_2, x.GetNode(node)); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), new_lhs, new_rhs}, node); + return new_cnode; + }; + auto exp_merge_lambda = [&node, &x, &y]() -> AnfNodePtr { + auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node); + return new_cnode; + }; + auto sqrt_merge_lambda = [&node, &x, &y]() -> AnfNodePtr { + auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), y.GetNode(node)}, node); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), node_tmp}, node); + return new_cnode; + }; + auto rsqrt_merge_lambda = [&node, &x]() -> AnfNodePtr { + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimReciprocal), x.GetNode(node)}, node); + return new_cnode; + }; + auto rsqrt_merge_lambda_2 = [&node, &x, &y]() -> AnfNodePtr { + auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), y.GetNode(node)}, node); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRsqrt), node_tmp}, node); + return new_cnode; + 
}; + auto rsqrt_merge_lambda_3 = [&node, &x]() -> AnfNodePtr { + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), x.GetNode(node)}, node); + return new_cnode; + }; + // (x*C1)*(y*C2) ==> (x*y)*(C1*C2) + MATCH_REPLACE_LAMBDA(node, (const_1 * x) * (const_2 * y), const_dup_lambda); + // exp(x)*exp(y) ==> exp(x+y) + MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) * PUnaryOperation(prim::kPrimExp, y), exp_merge_lambda); + // sqrt(x)*sqrt(x) ==> x + MATCH_REPLACE_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y), x, + PIsEqual()(x.GetNode(node), y.GetNode(node))); + // sqrt(x)*sqrt(y) ==> sqrt(x*y) + MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y), + sqrt_merge_lambda, !PIsEqual()(x.GetNode(node), y.GetNode(node))); + // rsqrt(x)*rsqrt(x) ==> 1/x + MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimRsqrt, x) * PUnaryOperation(prim::kPrimRsqrt, y), + rsqrt_merge_lambda, PIsEqual()(x.GetNode(node), y.GetNode(node))); + // rsqrt(x)*rsqrt(y) ==> rsqrt(x*y) + MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimRsqrt, x) * PUnaryOperation(prim::kPrimRsqrt, y), + rsqrt_merge_lambda_2, !PIsEqual()(x.GetNode(node), y.GetNode(node))); + // x*rsqrt(x) ==> sqrt(x) + MATCH_REPLACE_LAMBDA_IF(node, x * PUnaryOperation(prim::kPrimRsqrt, y), rsqrt_merge_lambda_3, + PIsEqual()(x.GetNode(node), y.GetNode(node))); + return nullptr; +} + +AnfNodePtr SimplifyDiv(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimRealDiv)) { + return nullptr; + } + PatternNode x, y, u, v; + PConstant const_1(node), const_2(node); + PConstant const_one(node, false, 1); + PConstant const_one_scalar(node, false, 1, true); + + auto div_exp_lambda_1 = [&node, &x, &y]() -> AnfNodePtr { + auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimSub), x.GetNode(node), y.GetNode(node)}, node); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, 
node); + return new_cnode; + }; + auto div_exp_lambda_2 = [&node, &x, &y]() -> AnfNodePtr { + auto node_neg = NewCNodeWithInfo({NewValueNode(prim::kPrimNeg), y.GetNode(node)}, node); + auto node_exp = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_neg}, node); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), node_exp}, node); + return new_cnode; + }; + auto div_pow_const = [&node, &x, &y, &const_1]() -> AnfNodePtr { + auto new_const = const_1.ValueNodeWithOprations(prim::kPrimNeg); + auto new_rhs = NewCNodeWithInfo({NewValueNode(prim::kPrimPow), y.GetNode(node), new_const}, node); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_rhs}, node); + return new_cnode; + }; + auto div_sqrt_lambda_1 = [&node, &x]() -> AnfNodePtr { + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), x.GetNode(node)}, node); + return new_cnode; + }; + auto div_sqrt_lambda_2 = [&node, &x, &y]() -> AnfNodePtr { + auto node_rsqrt = NewCNodeWithInfo({NewValueNode(prim::kPrimRsqrt), y.GetNode(node)}, node); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), node_rsqrt}, node); + return new_cnode; + }; + auto div_const = [&node, &x, &const_1]() -> AnfNodePtr { + auto new_const = const_1.ValueNodeWithOprations(prim::kPrimReciprocal); + if (new_const == nullptr) { + return nullptr; + } + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_const}, node); + return new_cnode; + }; + auto div_rsqrt_lambda = [&node, &x, &y]() -> AnfNodePtr { + auto node_rsqrt = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), y.GetNode(node)}, node); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), node_rsqrt}, node); + return new_cnode; + }; + auto div_lambda_1 = [&node, &x, &y, &u, &v]() -> AnfNodePtr { + auto new_lhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), v.GetNode(node)}, node); + auto new_rhs = 
NewCNodeWithInfo({NewValueNode(prim::kPrimMul), y.GetNode(node), u.GetNode(node)}, node); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), new_lhs, new_rhs}, node); + return new_cnode; + }; + auto div_lambda_2 = [&node, &x, &y, &u]() -> AnfNodePtr { + auto new_rhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), y.GetNode(node), u.GetNode(node)}, node); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), x.GetNode(node), new_rhs}, node); + return new_cnode; + }; + auto div_lambda_3 = [&node, &x, &u, &v]() -> AnfNodePtr { + auto new_lhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), v.GetNode(node)}, node); + auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), new_lhs, u.GetNode(node)}, node); + return new_cnode; + }; + // x/1 ==> x + MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarDiv, x, const_one_scalar, false), x); + MATCH_REPLACE(node, x / const_one, x); + // e^x/e^y ==> e^(x-y) + MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_1); + // x / e^y ==> x * e^(-y) + MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_2); + // x / y^const ==> x * y^(-const) + MATCH_REPLACE_LAMBDA(node, x / PBinOperation(prim::kPrimPow, y, const_1), div_pow_const); + // x / sqrt(x) ==> sqrt(x) + MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_1, + PIsEqual()(x.GetNode(node), y.GetNode(node))); + // x / sqrt(y) ==> x * rsqrt(y) + MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_2, + !PIsEqual()(x.GetNode(node), y.GetNode(node))); + // x / rsqrt(y) ==> x * sqrt(y) + MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimRsqrt, y), div_rsqrt_lambda); + // // x / const ==> x * (1/const) + MATCH_REPLACE_LAMBDA(node, x / const_1, div_const); + // (x/y) / (u/v) ==> (x*v) / (y*u) + MATCH_REPLACE_LAMBDA(node, (x / y) / (u / v), 
div_lambda_1); + // (x/y) / u ==> x / (y*u) + MATCH_REPLACE_LAMBDA(node, (x / y) / u, div_lambda_2); + // x / (u/v) ==> (x*v) / u + MATCH_REPLACE_LAMBDA(node, x / (u / v), div_lambda_3); + return nullptr; +} + +#define PERFORM_REPLACE(OldNode, NewNode, Graph, FLAG) \ + if ((NewNode) != nullptr) { \ + (Graph)->manager()->Replace((OldNode), (NewNode)); \ + (FLAG) = true; \ + } + +AnfNodePtr TrySimplify(const AnfNodePtr &node) { + std::list> SimplifyFuncList = { + SimplifyAdd, SimplifyDiv, SimplifyLog, SimplifyMul, SimplifyNeg, + SimplifyPow, SimplifyRsqrt, SimplifySelect, SimplifySqrt, SimplifySub}; + for (auto f : SimplifyFuncList) { + auto ret = f(node); + if (ret != nullptr) { + return ret; + } + } + return nullptr; +} + +void InlineSubgraph(const CNodePtr &kernel_node, const FuncGraphPtr &sub_graph, const FuncGraphPtr &main_func_graph) { + AnfNodePtrList ins; + ins.insert(ins.end(), kernel_node->inputs().begin() + 1, kernel_node->inputs().end()); + auto out = InlineClone(sub_graph, main_func_graph, ins, kernel_node->input(0)->scope()); + auto mng = main_func_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + mng->Replace(kernel_node, out); +} + +CNodePtr AddIdentityToEmptyPath(const AnfNodePtr &node, const FuncGraphPtr &sub_graph) { + if (node->isa() || node->isa()) { + auto identity_node = sub_graph->NewCNode({NewValueNode(prim::kPrimIdentity), node}); + identity_node->set_abstract(node->abstract()); + sub_graph->AddNode(identity_node); + return identity_node; + } + return nullptr; +} + +// If the return of the subgraph contains input Parameters or a new ValueNode, +// add identity mapping to them to avoid dealing with empty paths in subgraphs, +// then inline the subgraph into the main graph +bool CheckAndInlineEmptyGraph(const AnfNodePtr &node, const FuncGraphPtr &main_func_graph) { + if (!AnfAlgo::IsGraphKernel(node)) { + MS_LOG(ERROR) << node->ToString() << "is not a graph kernel\n"; + return false; + } + auto kernel_node = node->cast(); + auto sub_graph = 
AnfAlgo::GetCNodeFuncGraphPtr(kernel_node); + auto mng = sub_graph->manager(); + if (mng == nullptr) { + mng = Manage(sub_graph, false); + sub_graph->set_manager(mng); + } + auto sub_return = sub_graph->get_return(); + auto pred_node_of_return = sub_return->input(1); + bool ret = false; + if (!IsPrimitiveCNode(pred_node_of_return, prim::kPrimMakeTuple)) { // Single output + auto new_cnode = AddIdentityToEmptyPath(pred_node_of_return, sub_graph); + if (new_cnode != nullptr) { + sub_return->set_input(1, new_cnode); + ret = true; + } + } else { // Multiple output + auto maketuple_node = pred_node_of_return->cast(); + size_t size_ret = maketuple_node->inputs().size(); + size_t empty_path_cnt = 0; + for (size_t i = 1; i < size_ret; i++) { + auto tmp_node = maketuple_node->input(i); + auto new_cnode = AddIdentityToEmptyPath(tmp_node, sub_graph); + if (new_cnode != nullptr) { + maketuple_node->set_input(i, new_cnode); + empty_path_cnt++; + } + } + if (empty_path_cnt == 0) { // normal subgraph + return false; + } else if (empty_path_cnt < size_ret - 1) { + MS_EXCEPTION(NotSupportError); + return false; + } else { // empty subgraph + ret = true; + } + } + if (ret) { + InlineSubgraph(kernel_node, sub_graph, main_func_graph); + } + return ret; +} + +AnfNodePtr MatchIdentity(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimIdentity)) { + return nullptr; + } + PatternNode x; + MATCH_REPLACE(node, PUnaryOperation(prim::kPrimIdentity, x), x); + return nullptr; +} + +void EliminateEmptyGraph(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto mng = func_graph->manager(); + if (mng == nullptr) { + mng = Manage(func_graph, false); + func_graph->set_manager(mng); + } + bool empty_graph = false; + auto cnodes = func_graph->GetOrderedCnodes(); + for (auto cnode : cnodes) { + if (AnfAlgo::IsGraphKernel(cnode)) { + empty_graph = empty_graph || CheckAndInlineEmptyGraph(cnode, func_graph); + } + } + if (empty_graph) { + cnodes = 
func_graph->GetOrderedCnodes(); + for (auto cnode : cnodes) { + auto node = cnode->cast(); + auto new_node = MatchIdentity(node); + if (new_node != nullptr) { + mng->Replace(node, new_node); + } + } + } +} + +bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto mng = func_graph->manager(); + if (mng == nullptr) { + mng = Manage(func_graph, true); + func_graph->set_manager(mng); + } + bool replaced = false; + for (auto node : func_graph->GetOrderedCnodes()) { + if (AnfAlgo::IsGraphKernel(node)) { + auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + auto mng_sub = sub_graph->manager(); + if (mng_sub == nullptr) { + mng_sub = Manage(sub_graph, false); + sub_graph->set_manager(mng_sub); + } + bool sub_graph_changed = false; + for (auto node_sub : sub_graph->GetOrderedCnodes()) { + auto new_node = TrySimplify(node_sub); + if (new_node != nullptr) { + sub_graph_changed = true; + PERFORM_REPLACE(node_sub->cast(), new_node, sub_graph, replaced); + } + } + if (sub_graph_changed) { + ResetKernelInfo(node, AKG_KERNEL); + } + } else { + auto new_node = TrySimplify(node); + PERFORM_REPLACE(node->cast(), new_node, func_graph, replaced); + } + } + EliminateEmptyGraph(func_graph); + return replaced; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.h new file mode 100644 index 00000000000..46ed05961c4 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ARITHMETIC_SIMPLIFY_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ARITHMETIC_SIMPLIFY_H_ + +#include +#include "backend/optimizer/common/optimizer.h" +#include "ir/func_graph.h" + +namespace mindspore { +namespace opt { +class ArithmeticSimplify : public Pass { + public: + ArithmeticSimplify() : Pass("arithmetic_simplify") {} + ~ArithmeticSimplify() override = default; + bool Run(const FuncGraphPtr &func_graph) override; +}; +using ArithmeticSimplifyPtr = std::shared_ptr; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ARITHMETIC_SIMPLIFY_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc index ff8d27c0ed2..913314d1614 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc @@ -713,26 +713,6 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { std::vector need_inline_; }; -// Eliminate the redundant MakeTuple-GetItem operations. 
-void EliminateTupleGetItem(const FuncGraphPtr &func_graph) { - auto callback = [](const AnfNodePtr &node) { - auto cnode = node->cast(); - if (cnode == nullptr) return; - for (size_t i = 1; i < cnode->size(); ++i) { - auto getitem = cnode->input(i); - if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue; - auto getitem_cnode = getitem->cast(); - auto maketuple = getitem_cnode->input(kRealInputNodeIndexInTupleGetItem); - if (!AnfAlgo::CheckPrimitiveType(maketuple, prim::kPrimMakeTuple)) continue; - auto maketuple_cnode = maketuple->cast(); - int getitem_idx = - GetValue(getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem)->cast()->value()); - cnode->set_input(i, maketuple_cnode->input(getitem_idx + 1)); - } - }; - TraverseFuncGraph(func_graph, callback); -} - bool TrySplit(const CNodePtr &sub_root_cnode) { MS_LOG(INFO) << "Split process node: " << sub_root_cnode->fullname_with_scope(); auto splitter = Splitter::MakeSplitter(sub_root_cnode, std::make_shared()); @@ -761,9 +741,6 @@ bool GraphKernelSplitter::Run(const FuncGraphPtr &func_graph) { changed = TrySplit(node) || changed; } } - if (changed) { - EliminateTupleGetItem(func_graph); - } mng->RemoveRoots(); mng->KeepRoots({func_graph}); return changed; diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index e3456a16990..551561a2347 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -43,6 +43,7 @@ #include "backend/optimizer/graph_kernel/graph_kernel_expander.h" #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" #include "backend/optimizer/graph_kernel/composite_ops_fusion.h" +#include "backend/optimizer/graph_kernel/arithmetic_simplify.h" #include "runtime/device/kernel_runtime_manager.h" #include "utils/ms_utils.h" #include "utils/config_manager.h" @@ -116,7 +117,11 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr &kernel_ 
pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); + // After Simplify and Splitter, a lot of redundant getitem/maketuple + // will be exposed, use GetitemTuple Pass to delete them. + pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); (void)optimizer->Optimize(kernel_graph); diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h index 61b963ccb79..eae586bc204 100644 --- a/mindspore/core/ir/pattern_matcher.h +++ b/mindspore/core/ir/pattern_matcher.h @@ -152,6 +152,49 @@ class PBinOperation : public PBase > { mutable AnfNodePtr captured_binop_node_{nullptr}; }; +template +class PUnaryOperation : public PBase > { + public: + PUnaryOperation(const PrimitivePtr &prim, const T &x) : prim_(prim), x_(x) {} + ~PUnaryOperation() = default; + + AnfNodePtr GetNode(const AnfNodePtr &node) const { + AnfNodePtrList list = {NewValueNode(prim_), x_.GetNode(node)}; + return NewCNode(list, node->func_graph()); + } + + bool TryCapture_(const AnfNodePtr &node) const { + if (IsPrimitiveCNode(node, prim_)) { + auto cnode = node->cast(); + auto inputs = cnode->inputs(); + if (inputs.size() == 2 && x_.TryCapture(inputs[1])) { + captured_unaryop_node_ = node; + return true; + } + } + return false; + } + + AnfNodePtr GetOriginalNode() const { + if (captured_unaryop_node_ == nullptr) { + MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it."; + } + return captured_unaryop_node_; + } + + void Reset() const { + x_.Reset(); + captured_unaryop_node_ = nullptr; + } + + using Internal = const PUnaryOperation &; + + private: + const PrimitivePtr prim_; + typename T::Internal x_; + mutable AnfNodePtr captured_unaryop_node_{nullptr}; +}; + /// /// Helper functions to apply a pattern function on all elements of a tuple /// @@ -681,10 +724,74 @@ class PConstant : public 
PBase > { return new_vnode; } - // Support function to multiply two constant tensors: partially support broadcasting shapes + template + TD CalcuConstant(const TD &data, const PrimitivePtr &calcu_type) { + TD tmp_data = data; + if (calcu_type == prim::kPrimReciprocal) { + if (data == 0) { + MS_EXCEPTION(ValueError); + } else { + tmp_data = 1 / data; + } + } + if (calcu_type == prim::kPrimNeg) { + tmp_data = -data; + } + return tmp_data; + } + + // calculate const with different operations + AnfNodePtr ValueNodeWithOprations(const PrimitivePtr &calcu_type) { + AnfNodePtr node = this->GetNode(captured_node_); + if (!node->isa()) { + MS_EXCEPTION(ValueError) << "CalcuValue is trying to use a not ValueNode."; + } + auto value = node->cast()->value(); + if (value->isa()) { + tensor::TensorPtr tensor_ptr = dyn_cast(value); + TypeId tensor_type = tensor_ptr->Dtype()->type_id(); + if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat) || + (tensor_type == TypeId::kNumberTypeFloat64)) { + float *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) { + return nullptr; + } + data2[i] = CalcuConstant(data2[i], calcu_type); + } + } + if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { + int *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) { + return nullptr; + } + data2[i] = CalcuConstant(data2[i], calcu_type); + } + } + if (tensor_type == TypeId::kNumberTypeFloat64) { + double *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) { + return nullptr; + } + data2[i] = CalcuConstant(data2[i], calcu_type); + } + } + return node; + } + return nullptr; + } + + enum BinOperator { + 
ADD = 0, + MULTIPLY, + }; + + // Support function to add/multiply two constant tensors: partially support broadcasting shapes template - void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, - int out_data_size) const { + void CalcByOperator(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, + int out_data_size, BinOperator bin_operator) const { TM *data_1 = reinterpret_cast(in_data_1); TM *data_2 = reinterpret_cast(in_data_2); TM *data_out = new TM[out_data_size]; @@ -700,27 +807,42 @@ class PConstant : public PBase > { } if (in_data_2_size == 1) { for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[0]; + if (bin_operator == ADD) { + data_out[i] += data_2[0]; + } else { + data_out[i] *= data_2[0]; + } } } else { if (in_data_2_size < out_data_size) { MS_EXCEPTION(ValueError) << "in_data_2_size is smaller than out_data_size."; } for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[i]; + if (bin_operator == ADD) { + data_out[i] += data_2[i]; + } else { + data_out[i] *= data_2[i]; + } } } *out_data = reinterpret_cast(data_out); return; } + AnfNodePtr AddByPatternConst(const PConstant &vpnode_2, const AnfNodePtr &node_3) const { + AnfNodePtr vnode_1 = this->GetNode(captured_node_); + AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_); + return CalcConstantTensors(vnode_1, vnode_2, node_3, ADD); + } + AnfNodePtr MulByPatternConst(const PConstant &vpnode_2, const AnfNodePtr &node_3) const { AnfNodePtr vnode_1 = this->GetNode(captured_node_); AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_); - return MulConstantTensors(vnode_1, vnode_2, node_3); + return CalcConstantTensors(vnode_1, vnode_2, node_3, MULTIPLY); } - AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) const { + AnfNodePtr CalcConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3, + 
BinOperator bin_operator) const { if (!vnode_1->isa() || !vnode_2->isa() || (vnode_1->abstract() == nullptr) || (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { return nullptr; @@ -778,21 +900,21 @@ class PConstant : public PBase > { void *data_out = nullptr; if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat32) || (new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat)) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); + CalcByOperator(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator); ret = memcpy_s(data, mem_size, data_out, mem_size); delete[] reinterpret_cast(data_out); } else { if (new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat64) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); + CalcByOperator(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator); ret = memcpy_s(data, mem_size, data_out, mem_size); delete[] reinterpret_cast(data_out); } else { if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeInt32) || (new_tensor_ptr->data_type() == TypeId::kNumberTypeInt)) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); + CalcByOperator(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator); ret = memcpy_s(data, mem_size, data_out, mem_size); delete[] reinterpret_cast(data_out); } else { @@ -833,6 +955,8 @@ class PConstant : public PBase > { // Arithmetic operations BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd, true); BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true); 
+BIN_OPERATION_PATTERN(operator/, prim::kPrimRealDiv, false); +BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false); // Macros for match and replace #define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \ diff --git a/tests/st/ops/graph_kernel/test_simplify.py b/tests/st/ops/graph_kernel/test_simplify.py new file mode 100644 index 00000000000..57aaa205e21 --- /dev/null +++ b/tests/st/ops/graph_kernel/test_simplify.py @@ -0,0 +1,67 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================

import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P

context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")


class Net(Cell):
    """Network mixing add/sub/mul/div/sqrt/pow/neg so the graph-kernel
    arithmetic simplifier has concrete patterns to fold: constant-add
    fusion, sqrt(a)*sqrt(b), m/sqrt(m), pow(y, 2) and double negation.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.add = P.TensorAdd()
        self.sub = P.Sub()
        self.mul = P.Mul()
        self.div = P.RealDiv()
        self.sqrt = P.Sqrt()
        self.pow = P.Pow()
        self.neg = P.Neg()

    def construct(self, x, y):
        # (x + 4) + 5 -- candidate for (A + C1) + C2 = A + (C1 + C2)
        summed = self.add(self.add(x, 4), 5)
        # y - 3 -- candidate for A - const = A + (-const)
        shifted = self.sub(y, 3)
        # sqrt(a) * sqrt(b) -- candidate for fusing into sqrt(a * b)
        root_prod = self.mul(self.sqrt(summed), self.sqrt(shifted))
        # m / sqrt(m) -- candidate for reduction to sqrt(m)
        reduced = self.div(root_prod, self.sqrt(root_prod))
        # -(-(y ** 2)) -- double negation should cancel; pow(y, 2) -> y * y
        squared = self.neg(self.neg(self.pow(y, 2)))
        return self.add(reduced, squared)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic():
    """Run Net on random inputs and compare against the hand-simplified
    numpy reference; input_y is shifted to be >= ~3 so sqrt/div stay in
    the real, non-zero domain."""
    input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
    input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
    input_y = np.abs(input_y) + 3

    # numpy reference written in the already-simplified form the pass
    # is expected to produce
    expect = np.sqrt(np.sqrt((input_x + 9) * (input_y + (-3)))) + input_y * input_y

    result = Net()(Tensor(input_x), Tensor(input_y))

    assert np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)