!30849 add transpose_x2 check in MatmulConfusionTranposeFusionPass

Merge pull request !30849 from yuchaojie/r1.6_fix
This commit is contained in:
i-robot 2022-03-04 15:44:08 +00:00 committed by Gitee
commit 946ac31814
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 1 additions and 40 deletions

View File

@ -23,46 +23,7 @@
namespace mindspore {
namespace opt {
namespace {
constexpr auto kAttrTransposeX1 = "transpose_x1";
constexpr auto kAttrTransposeX2 = "transpose_x2";
struct WrongCase {
std::vector<size_t> matmul_input0_shape;
std::vector<size_t> matmul_input1_shape;
std::vector<size_t> transpose_output_shape;
bool transpose_x1;
bool transpose_x2;
};
bool CheckWrongShape(const AnfNodePtr &matmul, const AnfNodePtr &confusion_transpose) {
std::vector<WrongCase> wrong_cases;
// add wrong cases
WrongCase wrong_case1;
wrong_case1.matmul_input0_shape = {128, 1024};
wrong_case1.matmul_input1_shape = {1024, 1024};
wrong_case1.transpose_output_shape = {1, 16, 128, 64};
wrong_case1.transpose_x1 = false;
wrong_case1.transpose_x2 = true;
wrong_cases.push_back(std::move(wrong_case1));
// get node shape
auto matmul_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(matmul, 0);
auto matmul_input1_shape = AnfAlgo::GetPrevNodeOutputInferShape(matmul, 1);
auto transpose_output_shape = AnfAlgo::GetOutputInferShape(confusion_transpose, 0);
auto transpose_x1 = AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX1);
auto transpose_x2 = AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX2);
// check
return std::any_of(wrong_cases.begin(), wrong_cases.end(),
[matmul_input0_shape, matmul_input1_shape, transpose_output_shape, transpose_x1,
transpose_x2](WrongCase wrong_case) {
return wrong_case.matmul_input0_shape == matmul_input0_shape &&
wrong_case.matmul_input1_shape == matmul_input1_shape &&
wrong_case.transpose_output_shape == transpose_output_shape &&
wrong_case.transpose_x1 == transpose_x1 && wrong_case.transpose_x2 == transpose_x2;
});
}
} // namespace
void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNodePtr &cnode,
@ -74,7 +35,7 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode
MS_EXCEPTION_IF_NULL(matmul);
if (matmul->isa<CNode>() && (AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) ||
AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimBatchMatMul))) {
if (CheckWrongShape(matmul, cnode)) {
if (AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX2) == true) {
return;
}
mindspore::HashSet<AnfNodePtr> record{cnode, matmul};