diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc index 711ffff612c..f152f187317 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc @@ -24,40 +24,57 @@ #include "pre_activate/common/helper.h" namespace mindspore { -namespace opt { -const BaseRef MulAddFusion::DefinePattern() const { - VarPtr mul_x_ = std::make_shared(); - VarPtr mul_y_ = std::make_shared(); - VarPtr add_y_ = std::make_shared(); +bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(add); - VectorRef mul({prim::kPrimMul, mul_x_, mul_y_}); - VectorRef add({prim::kPrimTensorAdd, mul, add_y_}); - return add; + for (size_t index = 1; index < add->size(); ++index) { + auto input = add->input(index); + MS_EXCEPTION_IF_NULL(input); + if (input->isa()) { + auto cnode = input->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) { + if (!opt::IsUsedByOthers(graph, cnode)) { + *mul = cnode; + *mul_index = index; + return true; + } + } + } + } + return false; } -const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { - if (graph == nullptr || node == nullptr || equiv == nullptr) { +namespace opt { +const BaseRef MulAddFusion::DefinePattern() const { + VarPtr x = std::make_shared(); + VarPtr y = std::make_shared(); + VectorRef pattern({prim::kPrimTensorAdd, x, y}); + return pattern; +} + +const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + if (graph == nullptr || node == nullptr) { return nullptr; } auto add = node->cast(); if (add == nullptr || add->inputs().size() != kAddInputNum) { return nullptr; } - auto mul_anf = add->input(1); - if (mul_anf == nullptr) { - return nullptr; - } - auto mul = mul_anf->cast(); - if (mul == nullptr || mul->inputs().size() != kMulInputNum) { - return nullptr; - } - if (IsUsedByOthers(graph, mul)) { - MS_LOG(DEBUG) << "Mul is used by more then two nodes, cannot fuse"; + CNodePtr mul = nullptr; + size_t mul_index = 0; + if (!GetMul(graph, add, &mul, &mul_index) || mul == nullptr || mul_index == 0) { + MS_LOG(DEBUG) << "Cannot find used-by-only-one-op Mul in Add's inputs"; return nullptr; } auto prim = std::make_shared(kFusedMulAddOpName); - std::vector inputs = {NewValueNode(prim), mul->input(1), mul->input(2), add->input(2)}; + std::vector inputs = {NewValueNode(prim)}; + for (size_t index = 1; index < mul->size(); ++index) { + inputs.push_back(mul->input(index)); + } + inputs.push_back(add->input(add->size() - mul_index)); auto fusion_node = graph->NewCNode(inputs); fusion_node->set_scope(add->scope()); fusion_node->set_abstract(add->abstract()); diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc index 7477d6252cd..50221fca194 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc @@ -28,8 +28,28 @@ class TestHWMulAddFusion : public BackendCommon { UT::PyFuncGraphFetcher get_py_fun_; }; -TEST_F(TestHWMulAddFusion, test_mul_add_fusion) { - FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "before"); +TEST_F(TestHWMulAddFusion, test_mul_add_fusion1) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "before1"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 3; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWMulAddFusion, test_mul_add_fusion2) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "before2"); std::vector shp{2, 32, 224, 224}; auto x_abstract = std::make_shared(kFloat32, shp); AbstractBasePtrList args_spec_list; diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_add_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_add_fusion_test.py index 83a62233bfb..f3100b474aa 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_add_fusion_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_add_fusion_test.py @@ -21,7 +21,6 @@ fused_mul_add = Primitive('FusedMulAdd') make_tuple = Primitive('make_tuple') tuple_getitem = Primitive('tuple_getitem') - class FnDict: def __init__(self): self.fnDict = {} @@ -32,16 +31,21 @@ class FnDict: def __getitem__(self, name): return self.fnDict[name] - def test_mul_add_fusion(tag): fns = FnDict() @fns - def before(x, y, z): + def before1(x, y, z): res = mul(x, y) res = add(res, z) return res + @fns + def before2(x, y, z): + res = mul(x, y) + res = add(z, res) + return res + @fns def after(x, y, z): res = fused_mul_add(x, y, z)