add case check for MatmulConfusionTranposeFusionPass

This commit is contained in:
yuchaojie 2022-03-03 15:44:50 +08:00
parent f1c1acd681
commit ffb9bbbbd8
1 changed files with 46 additions and 0 deletions

View File

@ -22,6 +22,49 @@
namespace mindspore {
namespace opt {
namespace {
constexpr auto kAttrTransposeX1 = "transpose_x1";
constexpr auto kAttrTransposeX2 = "transpose_x2";
struct WrongCase {
std::vector<size_t> matmul_input0_shape;
std::vector<size_t> matmul_input1_shape;
std::vector<size_t> transpose_output_shape;
bool transpose_x1;
bool transpose_x2;
};
bool CheckWrongShape(const AnfNodePtr &matmul, const AnfNodePtr &confusion_transpose) {
std::vector<WrongCase> wrong_cases;
// add wrong cases
WrongCase wrong_case1;
wrong_case1.matmul_input0_shape = {128, 1024};
wrong_case1.matmul_input1_shape = {1024, 1024};
wrong_case1.transpose_output_shape = {1, 16, 128, 64};
wrong_case1.transpose_x1 = false;
wrong_case1.transpose_x2 = true;
wrong_cases.push_back(std::move(wrong_case1));
// get node shape
auto matmul_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(matmul, 0);
auto matmul_input1_shape = AnfAlgo::GetPrevNodeOutputInferShape(matmul, 1);
auto transpose_output_shape = AnfAlgo::GetOutputInferShape(confusion_transpose, 0);
auto transpose_x1 = AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX1);
auto transpose_x2 = AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX2);
// check
return std::any_of(wrong_cases.begin(), wrong_cases.end(),
[matmul_input0_shape, matmul_input1_shape, transpose_output_shape, transpose_x1,
transpose_x2](WrongCase wrong_case) {
return wrong_case.matmul_input0_shape == matmul_input0_shape &&
wrong_case.matmul_input1_shape == matmul_input1_shape &&
wrong_case.transpose_output_shape == transpose_output_shape &&
wrong_case.transpose_x1 == transpose_x1 && wrong_case.transpose_x2 == transpose_x2;
});
}
} // namespace
void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNodePtr &cnode,
const session::KernelGraph & /* kernel_graph */,
FusedNodeRecord *candidate_fusion) {
@ -31,6 +74,9 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode
MS_EXCEPTION_IF_NULL(matmul);
if (matmul->isa<CNode>() && (AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) ||
AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimBatchMatMul))) {
if (CheckWrongShape(matmul, cnode)) {
return;
}
mindspore::HashSet<AnfNodePtr> record{cnode, matmul};
candidate_fusion->push_back(record);
SetRecordFusionId(record);