From 64c41301ce4c3bfc1f5d42423595b9084e36a824 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Mon, 1 Mar 2021 09:40:43 +0000 Subject: [PATCH] [DAG] visitVECTOR_SHUFFLE - move shuffle canonicalization/merges all under the same legality test. NFCI. Minor cleanup to move related combines closer together to make it more coherent, without changing the ordering. --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 118 +++++++++--------- 1 file changed, 58 insertions(+), 60 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index faaa28963a2c..d896b8c0cdef 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -21043,37 +21043,6 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { } } - if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) { - // Canonicalize shuffles according to rules: - // shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A) - // shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B) - // shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B) - if (N1.getOpcode() == ISD::VECTOR_SHUFFLE && - N0.getOpcode() != ISD::VECTOR_SHUFFLE) { - // The incoming shuffle must be of the same type as the result of the - // current shuffle. - assert(N1->getOperand(0).getValueType() == VT && - "Shuffle types don't match"); - - SDValue SV0 = N1->getOperand(0); - SDValue SV1 = N1->getOperand(1); - bool HasSameOp0 = N0 == SV0; - bool IsSV1Undef = SV1.isUndef(); - if (HasSameOp0 || IsSV1Undef || N0 == SV1) - // Commute the operands of this shuffle so merging below will trigger. - return DAG.getCommutedVectorShuffle(*SVN); - } - - // Canonicalize splat shuffles to the RHS to improve merging below. 
- // shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u)) - if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && - N1.getOpcode() == ISD::VECTOR_SHUFFLE && - cast<ShuffleVectorSDNode>(N0)->isSplat() && - !cast<ShuffleVectorSDNode>(N1)->isSplat()) { - return DAG.getCommutedVectorShuffle(*SVN); - } - } - // Compute the combined shuffle mask for a shuffle with SV0 as the first // operand, and SV1 as the second operand. // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false @@ -21191,36 +21160,65 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { return TLI.isShuffleMaskLegal(Mask, VT); }; - // Try to fold according to rules: - // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2) - // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2) - // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2) - // Don't try to fold shuffles with illegal type. - // Only fold if this shuffle is the only user of the other shuffle. - if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && N->isOnlyUserOf(N0.getNode()) && - Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) { - // The incoming shuffle must be of the same type as the result of the - // current shuffle. - auto *OtherSV = cast<ShuffleVectorSDNode>(N0); - assert(OtherSV->getOperand(0).getValueType() == VT && - "Shuffle types don't match"); - - SDValue SV0, SV1; - SmallVector<int, 4> Mask; - if (MergeInnerShuffle(false, SVN, OtherSV, N1, TLI, SV0, SV1, Mask)) { - // Check if all indices in Mask are Undef. In case, propagate Undef. - if (llvm::all_of(Mask, [](int M) { return M < 0; })) - return DAG.getUNDEF(VT); - - return DAG.getVectorShuffle(VT, SDLoc(N), SV0 ? SV0 : DAG.getUNDEF(VT), - SV1 ? SV1 : DAG.getUNDEF(VT), Mask); - } - } - - // Merge shuffles through binops if we are able to merge it with at least one - // other shuffles.
- // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d))) if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) { + // Canonicalize shuffles according to rules: + // shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A) + // shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B) + // shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B) + if (N1.getOpcode() == ISD::VECTOR_SHUFFLE && + N0.getOpcode() != ISD::VECTOR_SHUFFLE) { + // The incoming shuffle must be of the same type as the result of the + // current shuffle. + assert(N1->getOperand(0).getValueType() == VT && + "Shuffle types don't match"); + + SDValue SV0 = N1->getOperand(0); + SDValue SV1 = N1->getOperand(1); + bool HasSameOp0 = N0 == SV0; + bool IsSV1Undef = SV1.isUndef(); + if (HasSameOp0 || IsSV1Undef || N0 == SV1) + // Commute the operands of this shuffle so merging below will trigger. + return DAG.getCommutedVectorShuffle(*SVN); + } + + // Canonicalize splat shuffles to the RHS to improve merging below. + // shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u)) + if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && + N1.getOpcode() == ISD::VECTOR_SHUFFLE && + cast<ShuffleVectorSDNode>(N0)->isSplat() && + !cast<ShuffleVectorSDNode>(N1)->isSplat()) { + return DAG.getCommutedVectorShuffle(*SVN); + } + + // Try to fold according to rules: + // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2) + // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2) + // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2) + // Don't try to fold shuffles with illegal type. + // Only fold if this shuffle is the only user of the other shuffle. + if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && + N->isOnlyUserOf(N0.getNode())) { + // The incoming shuffle must be of the same type as the result of the + // current shuffle.
+ auto *OtherSV = cast<ShuffleVectorSDNode>(N0); + assert(OtherSV->getOperand(0).getValueType() == VT && + "Shuffle types don't match"); + + SDValue SV0, SV1; + SmallVector<int, 4> Mask; + if (MergeInnerShuffle(false, SVN, OtherSV, N1, TLI, SV0, SV1, Mask)) { + // Check if all indices in Mask are Undef. In case, propagate Undef. + if (llvm::all_of(Mask, [](int M) { return M < 0; })) + return DAG.getUNDEF(VT); + + return DAG.getVectorShuffle(VT, SDLoc(N), SV0 ? SV0 : DAG.getUNDEF(VT), + SV1 ? SV1 : DAG.getUNDEF(VT), Mask); + } + } + + // Merge shuffles through binops if we are able to merge it with at least + // one other shuffles. + // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d))) unsigned SrcOpcode = N0.getOpcode(); if (SrcOpcode == N1.getOpcode() && TLI.isBinOp(SrcOpcode) && N->isOnlyUserOf(N0.getNode()) && N->isOnlyUserOf(N1.getNode())) {