From d646cf1ff65a343dc890a81211108362a4485567 Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Fri, 4 Mar 2022 17:19:38 +0800 Subject: [PATCH] add transpose_x2 check in MatmulConfusionTranposeFusionPass --- .../matmul_confusiontranspose_fusion_pass.cc | 41 +------------------ 1 file changed, 1 insertion(+), 40 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc index 6122537f7ee..88d56c7afde 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc @@ -23,46 +23,7 @@ namespace mindspore { namespace opt { namespace { -constexpr auto kAttrTransposeX1 = "transpose_x1"; constexpr auto kAttrTransposeX2 = "transpose_x2"; - -struct WrongCase { - std::vector matmul_input0_shape; - std::vector matmul_input1_shape; - std::vector transpose_output_shape; - bool transpose_x1; - bool transpose_x2; -}; - -bool CheckWrongShape(const AnfNodePtr &matmul, const AnfNodePtr &confusion_transpose) { - std::vector wrong_cases; - - // add wrong cases - WrongCase wrong_case1; - wrong_case1.matmul_input0_shape = {128, 1024}; - wrong_case1.matmul_input1_shape = {1024, 1024}; - wrong_case1.transpose_output_shape = {1, 16, 128, 64}; - wrong_case1.transpose_x1 = false; - wrong_case1.transpose_x2 = true; - wrong_cases.push_back(std::move(wrong_case1)); - - // get node shape - auto matmul_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(matmul, 0); - auto matmul_input1_shape = AnfAlgo::GetPrevNodeOutputInferShape(matmul, 1); - auto transpose_output_shape = AnfAlgo::GetOutputInferShape(confusion_transpose, 0); - auto transpose_x1 = AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX1); - auto transpose_x2 = AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX2); - - // check - return std::any_of(wrong_cases.begin(), wrong_cases.end(), - [matmul_input0_shape, matmul_input1_shape, transpose_output_shape, transpose_x1, - transpose_x2](WrongCase wrong_case) { - return wrong_case.matmul_input0_shape == matmul_input0_shape && - wrong_case.matmul_input1_shape == matmul_input1_shape && - wrong_case.transpose_output_shape == transpose_output_shape && - wrong_case.transpose_x1 == transpose_x1 && wrong_case.transpose_x2 == transpose_x2; - }); -} } // namespace void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNodePtr &cnode, @@ -74,7 +35,7 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode MS_EXCEPTION_IF_NULL(matmul); if (matmul->isa() && (AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) || AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimBatchMatMul))) { - if (CheckWrongShape(matmul, cnode)) { + if (AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX2) == true) { return; } mindspore::HashSet record{cnode, matmul};