add case check for MatmulConfusionTranposeFusionPass
This commit is contained in:
parent
f1c1acd681
commit
ffb9bbbbd8
|
@ -22,6 +22,49 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto kAttrTransposeX1 = "transpose_x1";
|
||||
constexpr auto kAttrTransposeX2 = "transpose_x2";
|
||||
|
||||
struct WrongCase {
|
||||
std::vector<size_t> matmul_input0_shape;
|
||||
std::vector<size_t> matmul_input1_shape;
|
||||
std::vector<size_t> transpose_output_shape;
|
||||
bool transpose_x1;
|
||||
bool transpose_x2;
|
||||
};
|
||||
|
||||
bool CheckWrongShape(const AnfNodePtr &matmul, const AnfNodePtr &confusion_transpose) {
|
||||
std::vector<WrongCase> wrong_cases;
|
||||
|
||||
// add wrong cases
|
||||
WrongCase wrong_case1;
|
||||
wrong_case1.matmul_input0_shape = {128, 1024};
|
||||
wrong_case1.matmul_input1_shape = {1024, 1024};
|
||||
wrong_case1.transpose_output_shape = {1, 16, 128, 64};
|
||||
wrong_case1.transpose_x1 = false;
|
||||
wrong_case1.transpose_x2 = true;
|
||||
wrong_cases.push_back(std::move(wrong_case1));
|
||||
|
||||
// get node shape
|
||||
auto matmul_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(matmul, 0);
|
||||
auto matmul_input1_shape = AnfAlgo::GetPrevNodeOutputInferShape(matmul, 1);
|
||||
auto transpose_output_shape = AnfAlgo::GetOutputInferShape(confusion_transpose, 0);
|
||||
auto transpose_x1 = AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX1);
|
||||
auto transpose_x2 = AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX2);
|
||||
|
||||
// check
|
||||
return std::any_of(wrong_cases.begin(), wrong_cases.end(),
|
||||
[matmul_input0_shape, matmul_input1_shape, transpose_output_shape, transpose_x1,
|
||||
transpose_x2](WrongCase wrong_case) {
|
||||
return wrong_case.matmul_input0_shape == matmul_input0_shape &&
|
||||
wrong_case.matmul_input1_shape == matmul_input1_shape &&
|
||||
wrong_case.transpose_output_shape == transpose_output_shape &&
|
||||
wrong_case.transpose_x1 == transpose_x1 && wrong_case.transpose_x2 == transpose_x2;
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNodePtr &cnode,
|
||||
const session::KernelGraph & /* kernel_graph */,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
|
@ -31,6 +74,9 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode
|
|||
MS_EXCEPTION_IF_NULL(matmul);
|
||||
if (matmul->isa<CNode>() && (AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) ||
|
||||
AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimBatchMatMul))) {
|
||||
if (CheckWrongShape(matmul, cnode)) {
|
||||
return;
|
||||
}
|
||||
mindspore::HashSet<AnfNodePtr> record{cnode, matmul};
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
Loading…
Reference in New Issue