[X86] Improve combineVectorShiftImm

Summary:
Fold (shift (shift X, C2), C1) -> (shift X, (C1 + C2)) for logical as
well as arithmetic shifts. This is needed to prevent regressions from
an upcoming funnel shift expansion change.

While we're here, fold (VSRAI -1, C) -> -1 too.
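
For a concrete feel for the new fold, here is a standalone scalar sketch (plain C++, not the LLVM code; the helper names are made up for illustration): on 64-bit elements a logical right shift by 1 followed by one by 32 becomes a single shift by 33, and a combined amount of 64 or more folds to zero for logical shifts or clamps to 63 for arithmetic ones.

#include <cstdint>
#include <cstdio>

// Illustrative model of the per-lane effect of the combined fold:
// (shift (shift X, C2), C1) is rewritten as one shift by C1 + C2.
static uint64_t foldedLshr64(uint64_t X, unsigned C1, unsigned C2) {
  unsigned NewShiftVal = C1 + C2;
  if (NewShiftVal >= 64)
    return 0; // out-of-range logical shifts are guaranteed to be zero
  return X >> NewShiftVal;
}

static int64_t foldedAshr64(int64_t X, unsigned C1, unsigned C2) {
  unsigned NewShiftVal = C1 + C2;
  if (NewShiftVal >= 64)
    NewShiftVal = 63; // out-of-range arithmetic shifts splat the sign bit
  return X >> NewShiftVal; // '>>' on a negative value is arithmetic on
                           // mainstream targets (guaranteed since C++20)
}

int main() {
  uint64_t X = 0x8000000200000001ULL;
  // (srl (srl X, 1), 32) == (srl X, 33): the pattern in the tests below,
  // where a vpsrlq $32 of a vpsrlq $1 result becomes a single vpsrlq $33.
  std::printf("%d\n", ((X >> 1) >> 32) == foldedLshr64(X, 32, 1)); // prints 1
  std::printf("%d\n", foldedAshr64(-1, 40, 40) == -1);             // prints 1
}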

Reviewers: RKSimon, craig.topper

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77300
Jay Foad, 2020-04-02 12:20:35 +01:00
commit bc78baec4c (parent e8111502d8)
4 changed files with 926 additions and 917 deletions


@@ -41084,26 +41084,37 @@ static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG,
   if (ShiftVal >= NumBitsPerElt) {
     if (LogicalShift)
       return DAG.getConstant(0, SDLoc(N), VT);
-    else
-      ShiftVal = NumBitsPerElt - 1;
+    ShiftVal = NumBitsPerElt - 1;
   }
 
-  // Shift N0 by zero -> N0.
+  // (shift X, 0) -> X
   if (!ShiftVal)
     return N0;
 
-  // Shift zero -> zero.
+  // (shift 0, C) -> 0
   if (ISD::isBuildVectorAllZeros(N0.getNode()))
+    // N0 is all zeros or undef. We guarantee that the bits shifted into the
+    // result are all zeros, not undef.
     return DAG.getConstant(0, SDLoc(N), VT);
 
-  // Fold (VSRAI (VSRAI X, C1), C2) --> (VSRAI X, (C1 + C2)) with (C1 + C2)
-  // clamped to (NumBitsPerElt - 1).
-  if (Opcode == X86ISD::VSRAI && N0.getOpcode() == X86ISD::VSRAI) {
+  // (VSRAI -1, C) -> -1
+  if (!LogicalShift && ISD::isBuildVectorAllOnes(N0.getNode()))
+    // N0 is all ones or undef. We guarantee that the bits shifted into the
+    // result are all ones, not undef.
+    return DAG.getConstant(-1, SDLoc(N), VT);
+
+  // (shift (shift X, C2), C1) -> (shift X, (C1 + C2))
+  if (Opcode == N0.getOpcode()) {
     unsigned ShiftVal2 = cast<ConstantSDNode>(N0.getOperand(1))->getZExtValue();
     unsigned NewShiftVal = ShiftVal + ShiftVal2;
-    if (NewShiftVal >= NumBitsPerElt)
+    if (NewShiftVal >= NumBitsPerElt) {
+      // Out of range logical bit shifts are guaranteed to be zero.
+      // Out of range arithmetic bit shifts splat the sign bit.
+      if (LogicalShift)
+        return DAG.getConstant(0, SDLoc(N), VT);
       NewShiftVal = NumBitsPerElt - 1;
-    return DAG.getNode(X86ISD::VSRAI, SDLoc(N), VT, N0.getOperand(0),
+    }
+    return DAG.getNode(Opcode, SDLoc(N), VT, N0.getOperand(0),
                        DAG.getTargetConstant(NewShiftVal, SDLoc(N), MVT::i8));
   }
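
As a quick sanity check of the two guarantees the new comments spell out, here is a minimal standalone sketch (plain C++, not LLVM code; the 16-bit lane width is an arbitrary choice): an arithmetic right shift of an all-ones lane is always all ones, and two logical right shifts whose immediates sum to at least the lane width always produce zero, which is why the combine can return constant -1 and constant 0 vectors directly.

#include <cassert>
#include <cstdint>

int main() {
  const unsigned NumBitsPerElt = 16;

  // (VSRAI -1, C) -> -1: an arithmetic right shift keeps replicating the
  // sign bit, so an all-ones lane stays all ones for every in-range amount.
  const int16_t AllOnes = -1;
  for (unsigned C = 0; C < NumBitsPerElt; ++C)
    assert(int16_t(AllOnes >> C) == -1);

  // (shift (shift X, C2), C1) with C1 + C2 >= NumBitsPerElt: every original
  // bit has been shifted out, so a logical shift pair is guaranteed zero.
  const uint16_t X = 0xABCD;
  assert(uint16_t(uint16_t(X >> 9) >> 8) == 0); // 9 + 8 >= 16

  return 0;
}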

File diff suppressed because it is too large.

File diff suppressed because it is too large.


@@ -146,16 +146,16 @@ define <8 x i64> @vec512_i64_signed_reg_reg(<8 x i64> %a1, <8 x i64> %a2) nounwi
 ; ALL-NEXT:    vpminsq %zmm1, %zmm0, %zmm2
 ; ALL-NEXT:    vpmaxsq %zmm1, %zmm0, %zmm1
 ; ALL-NEXT:    vpsubq %zmm2, %zmm1, %zmm1
-; ALL-NEXT:    vpsrlq $1, %zmm1, %zmm1
-; ALL-NEXT:    vpsrlq $32, %zmm3, %zmm2
-; ALL-NEXT:    vpmuludq %zmm2, %zmm1, %zmm2
-; ALL-NEXT:    vpsrlq $32, %zmm1, %zmm4
-; ALL-NEXT:    vpmuludq %zmm3, %zmm4, %zmm4
-; ALL-NEXT:    vpaddq %zmm4, %zmm2, %zmm2
-; ALL-NEXT:    vpsllq $32, %zmm2, %zmm2
+; ALL-NEXT:    vpsrlq $1, %zmm1, %zmm2
+; ALL-NEXT:    vpsrlq $33, %zmm1, %zmm1
 ; ALL-NEXT:    vpmuludq %zmm3, %zmm1, %zmm1
-; ALL-NEXT:    vpaddq %zmm0, %zmm2, %zmm0
+; ALL-NEXT:    vpsrlq $32, %zmm3, %zmm4
+; ALL-NEXT:    vpmuludq %zmm4, %zmm2, %zmm4
+; ALL-NEXT:    vpaddq %zmm1, %zmm4, %zmm1
+; ALL-NEXT:    vpsllq $32, %zmm1, %zmm1
+; ALL-NEXT:    vpmuludq %zmm3, %zmm2, %zmm2
 ; ALL-NEXT:    vpaddq %zmm0, %zmm1, %zmm0
+; ALL-NEXT:    vpaddq %zmm0, %zmm2, %zmm0
 ; ALL-NEXT:    retq
   %t3 = icmp sgt <8 x i64> %a1, %a2 ; signed
   %t4 = select <8 x i1> %t3, <8 x i64> <i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1>, <8 x i64> <i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1>
@@ -178,16 +178,16 @@ define <8 x i64> @vec512_i64_unsigned_reg_reg(<8 x i64> %a1, <8 x i64> %a2) noun
 ; ALL-NEXT:    vpminuq %zmm1, %zmm0, %zmm2
 ; ALL-NEXT:    vpmaxuq %zmm1, %zmm0, %zmm1
 ; ALL-NEXT:    vpsubq %zmm2, %zmm1, %zmm1
-; ALL-NEXT:    vpsrlq $1, %zmm1, %zmm1
-; ALL-NEXT:    vpsrlq $32, %zmm3, %zmm2
-; ALL-NEXT:    vpmuludq %zmm2, %zmm1, %zmm2
-; ALL-NEXT:    vpsrlq $32, %zmm1, %zmm4
-; ALL-NEXT:    vpmuludq %zmm3, %zmm4, %zmm4
-; ALL-NEXT:    vpaddq %zmm4, %zmm2, %zmm2
-; ALL-NEXT:    vpsllq $32, %zmm2, %zmm2
+; ALL-NEXT:    vpsrlq $1, %zmm1, %zmm2
+; ALL-NEXT:    vpsrlq $33, %zmm1, %zmm1
 ; ALL-NEXT:    vpmuludq %zmm3, %zmm1, %zmm1
-; ALL-NEXT:    vpaddq %zmm0, %zmm2, %zmm0
+; ALL-NEXT:    vpsrlq $32, %zmm3, %zmm4
+; ALL-NEXT:    vpmuludq %zmm4, %zmm2, %zmm4
+; ALL-NEXT:    vpaddq %zmm1, %zmm4, %zmm1
+; ALL-NEXT:    vpsllq $32, %zmm1, %zmm1
+; ALL-NEXT:    vpmuludq %zmm3, %zmm2, %zmm2
 ; ALL-NEXT:    vpaddq %zmm0, %zmm1, %zmm0
+; ALL-NEXT:    vpaddq %zmm0, %zmm2, %zmm0
 ; ALL-NEXT:    retq
   %t3 = icmp ugt <8 x i64> %a1, %a2
   %t4 = select <8 x i1> %t3, <8 x i64> <i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1>, <8 x i64> <i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1>
@@ -213,16 +213,16 @@ define <8 x i64> @vec512_i64_signed_mem_reg(<8 x i64>* %a1_addr, <8 x i64> %a2)
 ; ALL-NEXT:    vpminsq %zmm0, %zmm1, %zmm2
 ; ALL-NEXT:    vpmaxsq %zmm0, %zmm1, %zmm0
 ; ALL-NEXT:    vpsubq %zmm2, %zmm0, %zmm0
-; ALL-NEXT:    vpsrlq $1, %zmm0, %zmm0
-; ALL-NEXT:    vpsrlq $32, %zmm3, %zmm2
-; ALL-NEXT:    vpmuludq %zmm2, %zmm0, %zmm2
-; ALL-NEXT:    vpsrlq $32, %zmm0, %zmm4
-; ALL-NEXT:    vpmuludq %zmm3, %zmm4, %zmm4
-; ALL-NEXT:    vpaddq %zmm4, %zmm2, %zmm2
-; ALL-NEXT:    vpsllq $32, %zmm2, %zmm2
+; ALL-NEXT:    vpsrlq $1, %zmm0, %zmm2
+; ALL-NEXT:    vpsrlq $33, %zmm0, %zmm0
 ; ALL-NEXT:    vpmuludq %zmm3, %zmm0, %zmm0
-; ALL-NEXT:    vpaddq %zmm1, %zmm2, %zmm1
+; ALL-NEXT:    vpsrlq $32, %zmm3, %zmm4
+; ALL-NEXT:    vpmuludq %zmm4, %zmm2, %zmm4
+; ALL-NEXT:    vpaddq %zmm0, %zmm4, %zmm0
+; ALL-NEXT:    vpsllq $32, %zmm0, %zmm0
+; ALL-NEXT:    vpmuludq %zmm3, %zmm2, %zmm2
 ; ALL-NEXT:    vpaddq %zmm1, %zmm0, %zmm0
+; ALL-NEXT:    vpaddq %zmm0, %zmm2, %zmm0
 ; ALL-NEXT:    retq
   %a1 = load <8 x i64>, <8 x i64>* %a1_addr
   %t3 = icmp sgt <8 x i64> %a1, %a2 ; signed
@@ -247,16 +247,16 @@ define <8 x i64> @vec512_i64_signed_reg_mem(<8 x i64> %a1, <8 x i64>* %a2_addr)
 ; ALL-NEXT:    vpminsq %zmm1, %zmm0, %zmm2
 ; ALL-NEXT:    vpmaxsq %zmm1, %zmm0, %zmm1
 ; ALL-NEXT:    vpsubq %zmm2, %zmm1, %zmm1
-; ALL-NEXT:    vpsrlq $1, %zmm1, %zmm1
-; ALL-NEXT:    vpsrlq $32, %zmm3, %zmm2
-; ALL-NEXT:    vpmuludq %zmm2, %zmm1, %zmm2
-; ALL-NEXT:    vpsrlq $32, %zmm1, %zmm4
-; ALL-NEXT:    vpmuludq %zmm3, %zmm4, %zmm4
-; ALL-NEXT:    vpaddq %zmm4, %zmm2, %zmm2
-; ALL-NEXT:    vpsllq $32, %zmm2, %zmm2
+; ALL-NEXT:    vpsrlq $1, %zmm1, %zmm2
+; ALL-NEXT:    vpsrlq $33, %zmm1, %zmm1
 ; ALL-NEXT:    vpmuludq %zmm3, %zmm1, %zmm1
-; ALL-NEXT:    vpaddq %zmm0, %zmm2, %zmm0
+; ALL-NEXT:    vpsrlq $32, %zmm3, %zmm4
+; ALL-NEXT:    vpmuludq %zmm4, %zmm2, %zmm4
+; ALL-NEXT:    vpaddq %zmm1, %zmm4, %zmm1
+; ALL-NEXT:    vpsllq $32, %zmm1, %zmm1
+; ALL-NEXT:    vpmuludq %zmm3, %zmm2, %zmm2
 ; ALL-NEXT:    vpaddq %zmm0, %zmm1, %zmm0
+; ALL-NEXT:    vpaddq %zmm0, %zmm2, %zmm0
 ; ALL-NEXT:    retq
   %a2 = load <8 x i64>, <8 x i64>* %a2_addr
   %t3 = icmp sgt <8 x i64> %a1, %a2 ; signed
@@ -282,16 +282,16 @@ define <8 x i64> @vec512_i64_signed_mem_mem(<8 x i64>* %a1_addr, <8 x i64>* %a2_
 ; ALL-NEXT:    vpminsq %zmm1, %zmm0, %zmm2
 ; ALL-NEXT:    vpmaxsq %zmm1, %zmm0, %zmm1
 ; ALL-NEXT:    vpsubq %zmm2, %zmm1, %zmm1
-; ALL-NEXT:    vpsrlq $1, %zmm1, %zmm1
-; ALL-NEXT:    vpsrlq $32, %zmm3, %zmm2
-; ALL-NEXT:    vpmuludq %zmm2, %zmm1, %zmm2
-; ALL-NEXT:    vpsrlq $32, %zmm1, %zmm4
-; ALL-NEXT:    vpmuludq %zmm3, %zmm4, %zmm4
-; ALL-NEXT:    vpaddq %zmm4, %zmm2, %zmm2
-; ALL-NEXT:    vpsllq $32, %zmm2, %zmm2
+; ALL-NEXT:    vpsrlq $1, %zmm1, %zmm2
+; ALL-NEXT:    vpsrlq $33, %zmm1, %zmm1
 ; ALL-NEXT:    vpmuludq %zmm3, %zmm1, %zmm1
-; ALL-NEXT:    vpaddq %zmm0, %zmm2, %zmm0
+; ALL-NEXT:    vpsrlq $32, %zmm3, %zmm4
+; ALL-NEXT:    vpmuludq %zmm4, %zmm2, %zmm4
+; ALL-NEXT:    vpaddq %zmm1, %zmm4, %zmm1
+; ALL-NEXT:    vpsllq $32, %zmm1, %zmm1
+; ALL-NEXT:    vpmuludq %zmm3, %zmm2, %zmm2
 ; ALL-NEXT:    vpaddq %zmm0, %zmm1, %zmm0
+; ALL-NEXT:    vpaddq %zmm0, %zmm2, %zmm0
 ; ALL-NEXT:    retq
   %a1 = load <8 x i64>, <8 x i64>* %a1_addr
   %a2 = load <8 x i64>, <8 x i64>* %a2_addr