diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc
index fdd390677af..4fa4b5e7b04 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc
@@ -14,38 +14,39 @@
  * limitations under the License.
  */
 #include "backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h"
-#include <memory>
+#include <vector>
 #include "backend/optimizer/common/helper.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "utils/utils.h"
 
 namespace mindspore {
 namespace opt {
-namespace {
-constexpr size_t kMatMulInputIndex = 1;
-constexpr size_t kBiasInputIndex = 2;
-}  // namespace
-
 const BaseRef MatmulBiasaddFusion::DefinePattern() const {
-  VarPtr X0 = std::make_shared<Var>();
-  VarPtr X1 = std::make_shared<Var>();
-  VarPtr X2 = std::make_shared<Var>();
-  const auto prim_bias_add = std::make_shared<Primitive>(kBiasAddOpName);
-  return VectorRef({prim_bias_add, VectorRef({prim::kPrimMatMul, X0, X1}), X2});
+  VectorRef matmul({matmul_var_, x0_, x1_});
+  VectorRef pattern({prim::kPrimBiasAdd, matmul, x2_});
+  return pattern;
 }
 
-const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
+const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                              const EquivPtr &equiv) const {
   MS_EXCEPTION_IF_NULL(node);
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  CheckCNodeInputSize(cnode, kBiasAddInputNum);
-  AnfNodePtr matmul = cnode->input(kMatMulInputIndex);
-  MS_EXCEPTION_IF_NULL(matmul);
-  auto matmul_cnode = matmul->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(matmul_cnode);
-  matmul_cnode->add_input(cnode->input(kBiasInputIndex));
-  AnfAlgo::SetNodeAttr(kAttrHasBias, MakeValue(true), matmul);
-  return matmul;
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> inputs;
+  inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name())));
+  inputs.emplace_back(GetAnfNodeByVar(equiv, x0_));
+  inputs.emplace_back(GetAnfNodeByVar(equiv, x1_));
+  inputs.emplace_back(GetAnfNodeByVar(equiv, x2_));
+  auto new_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(new_node);
+  new_node->set_scope(node->scope());
+  new_node->set_abstract(node->abstract());
+
+  auto matmul = GetAnfNodeByVar(equiv, matmul_var_);
+  if (matmul == nullptr || !matmul->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << "Get CNode MatMul failed!";
+  }
+  AnfAlgo::CopyNodeAttrs(matmul, new_node);
+  return new_node;
 }
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h
index a2f9e8b0ffa..23ad7ff66b6 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h
@@ -16,17 +16,29 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_
 
+#include <memory>
 #include "backend/optimizer/common/optimizer.h"
 
 namespace mindspore {
 namespace opt {
 class MatmulBiasaddFusion : public PatternProcessPass {
  public:
-  explicit MatmulBiasaddFusion(bool multigraph = true) : PatternProcessPass("matmul_biasadd_fusion", multigraph) {}
+  explicit MatmulBiasaddFusion(bool multigraph = true) : PatternProcessPass("matmul_biasadd_fusion", multigraph) {
+    x0_ = std::make_shared<Var>();
+    x1_ = std::make_shared<Var>();
+    x2_ = std::make_shared<Var>();
+    matmul_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMatMul->name()));
+  }
   ~MatmulBiasaddFusion() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr x0_;
+  VarPtr x1_;
+  VarPtr x2_;
+  VarPtr matmul_var_;
 };
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_