From 9ca028c2d6c492661c377ee2a4ab1cc62f8c2ecd Mon Sep 17 00:00:00 2001
From: Sanjay Patel
Date: Sun, 23 Oct 2016 23:13:31 +0000
Subject: [PATCH] [DAG] enhance computeKnownBits to handle SRL/SRA with vector splat constant

llvm-svn: 284953
---
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 75 ++++++++-----------
 llvm/test/CodeGen/X86/combine-sra.ll          |  7 +-
 llvm/test/CodeGen/X86/combine-srl.ll          | 20 +----
 3 files changed, 38 insertions(+), 64 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 974322aabc16..efb357a07c0b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1996,6 +1996,18 @@ bool SelectionDAG::MaskedValueIsZero(SDValue Op, const APInt &Mask,
   return (KnownZero & Mask) == Mask;
 }
 
+/// If a SHL/SRA/SRL node has a constant or splat constant shift amount that
+/// is less than the element bit-width of the shift node, return it.
+static const APInt *getValidShiftAmountConstant(SDValue V) {
+  if (ConstantSDNode *SA = isConstOrConstSplat(V.getOperand(1))) {
+    // Shifting more than the bitwidth is not valid.
+    const APInt &ShAmt = SA->getAPIntValue();
+    if (ShAmt.ult(V.getScalarValueSizeInBits()))
+      return &ShAmt;
+  }
+  return nullptr;
+}
+
 /// Determine which bits of Op are known to be either zero or one and return
 /// them in the KnownZero/KnownOne bitsets.
 void SelectionDAG::computeKnownBits(SDValue Op, APInt &KnownZero,
@@ -2144,57 +2156,34 @@ void SelectionDAG::computeKnownBits(SDValue Op, APInt &KnownZero,
     KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - 1);
     break;
   case ISD::SHL:
-    if (ConstantSDNode *SA = isConstOrConstSplat(Op.getOperand(1))) {
-      // If the shift count is an invalid immediate, don't do anything.
-      APInt ShAmt = SA->getAPIntValue();
-      if (ShAmt.uge(BitWidth))
-        break;
-
+    if (const APInt *ShAmt = getValidShiftAmountConstant(Op)) {
       computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth + 1);
-      KnownZero = KnownZero << ShAmt;
-      KnownOne = KnownOne << ShAmt;
-      // low bits known zero.
-      KnownZero |= APInt::getLowBitsSet(BitWidth, ShAmt.getZExtValue());
+      KnownZero = KnownZero << *ShAmt;
+      KnownOne = KnownOne << *ShAmt;
+      // Low bits are known zero.
+      KnownZero |= APInt::getLowBitsSet(BitWidth, ShAmt->getZExtValue());
     }
     break;
   case ISD::SRL:
-    // FIXME: Reuse isConstOrConstSplat + APInt from above.
-    if (ConstantSDNode *SA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
-      unsigned ShAmt = SA->getZExtValue();
-
-      // If the shift count is an invalid immediate, don't do anything.
-      if (ShAmt >= BitWidth)
-        break;
-
-      computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
-      KnownZero = KnownZero.lshr(ShAmt);
-      KnownOne = KnownOne.lshr(ShAmt);
-
-      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt);
-      KnownZero |= HighBits; // High bits known zero.
+    if (const APInt *ShAmt = getValidShiftAmountConstant(Op)) {
+      computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth + 1);
+      KnownZero = KnownZero.lshr(*ShAmt);
+      KnownOne = KnownOne.lshr(*ShAmt);
+      // High bits are known zero.
+      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt->getZExtValue());
+      KnownZero |= HighBits;
     }
     break;
   case ISD::SRA:
-    // FIXME: Reuse isConstOrConstSplat + APInt from above.
-    if (ConstantSDNode *SA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
-      unsigned ShAmt = SA->getZExtValue();
-
-      // If the shift count is an invalid immediate, don't do anything.
-      if (ShAmt >= BitWidth)
-        break;
-
-      // If any of the demanded bits are produced by the sign extension, we also
-      // demand the input sign bit.
-      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt);
-
-      computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
-      KnownZero = KnownZero.lshr(ShAmt);
-      KnownOne = KnownOne.lshr(ShAmt);
-
-      // Handle the sign bits.
+    if (const APInt *ShAmt = getValidShiftAmountConstant(Op)) {
+      computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth + 1);
+      KnownZero = KnownZero.lshr(*ShAmt);
+      KnownOne = KnownOne.lshr(*ShAmt);
+      // If we know the value of the sign bit, then we know it is copied across
+      // the high bits by the shift amount.
+      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt->getZExtValue());
       APInt SignBit = APInt::getSignBit(BitWidth);
-      SignBit = SignBit.lshr(ShAmt); // Adjust to where it is now in the mask.
-
+      SignBit = SignBit.lshr(*ShAmt); // Adjust to where it is now in the mask.
       if (KnownZero.intersects(SignBit)) {
         KnownZero |= HighBits; // New bits are known zero.
       } else if (KnownOne.intersects(SignBit)) {
diff --git a/llvm/test/CodeGen/X86/combine-sra.ll b/llvm/test/CodeGen/X86/combine-sra.ll
index 99051d554f46..10b5b67b9de6 100644
--- a/llvm/test/CodeGen/X86/combine-sra.ll
+++ b/llvm/test/CodeGen/X86/combine-sra.ll
@@ -300,15 +300,12 @@ define <4 x i32> @combine_vec_ashr_positive(<4 x i32> %x, <4 x i32> %y) {
 define <4 x i32> @combine_vec_ashr_positive_splat(<4 x i32> %x, <4 x i32> %y) {
 ; SSE-LABEL: combine_vec_ashr_positive_splat:
 ; SSE:       # BB#0:
-; SSE-NEXT:    pand {{.*}}(%rip), %xmm0
-; SSE-NEXT:    psrld $10, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_ashr_positive_splat:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpbroadcastd {{.*}}(%rip), %xmm1
-; AVX-NEXT:    vpand %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpsrld $10, %xmm0, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = and <4 x i32> %x,
   %2 = ashr <4 x i32> %1, <i32 10, i32 10, i32 10, i32 10>
diff --git a/llvm/test/CodeGen/X86/combine-srl.ll b/llvm/test/CodeGen/X86/combine-srl.ll
index 0c76583a4c43..b65a5c83bcf1 100644
--- a/llvm/test/CodeGen/X86/combine-srl.ll
+++ b/llvm/test/CodeGen/X86/combine-srl.ll
@@ -79,15 +79,12 @@ define <4 x i32> @combine_vec_lshr_by_zero(<4 x i32> %x) {
 define <4 x i32> @combine_vec_lshr_known_zero0(<4 x i32> %x) {
 ; SSE-LABEL: combine_vec_lshr_known_zero0:
 ; SSE:       # BB#0:
-; SSE-NEXT:    pand {{.*}}(%rip), %xmm0
-; SSE-NEXT:    psrld $4, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_lshr_known_zero0:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpbroadcastd {{.*}}(%rip), %xmm1
-; AVX-NEXT:    vpand %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpsrld $4, %xmm0, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = and <4 x i32> %x,
   %2 = lshr <4 x i32> %1, <i32 4, i32 4, i32 4, i32 4>
@@ -292,21 +289,12 @@ define <4 x i32> @combine_vec_lshr_trunc_lshr1(<4 x i64> %x) {
 define <4 x i32> @combine_vec_lshr_trunc_lshr_zero0(<4 x i64> %x) {
 ; SSE-LABEL: combine_vec_lshr_trunc_lshr_zero0:
 ; SSE:       # BB#0:
-; SSE-NEXT:    psrlq $48, %xmm0
-; SSE-NEXT:    psrlq $48, %xmm1
-; SSE-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[0,1,0,2]
-; SSE-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    psrld $24, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_lshr_trunc_lshr_zero0:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpsrlq $48, %ymm0, %ymm0
-; AVX-NEXT:    vpshufd {{.*#+}} ymm0 = ymm0[0,2,2,3,4,6,6,7]
-; AVX-NEXT:    vpermq {{.*#+}} ymm0 = ymm0[0,2,2,3]
-; AVX-NEXT:    vpsrld $24, %xmm0, %xmm0
-; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = lshr <4 x i64> %x, <i64 48, i64 48, i64 48, i64 48>
   %2 = trunc <4 x i64> %1 to <4 x i32>
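
Illustrative note (not part of the patch): the effect is easiest to see on a small splat-shift example in the style of the tests above. The function below is a hypothetical sketch; its name and constants are chosen only for illustration and do not come from the patch. Because the and limits every element to its low 3 bits and the splat lshr then shifts right by 3, computeKnownBits can now prove every result bit is zero for the vector (splat) shift, so the DAG combiner folds the whole sequence to a zero vector (a single xorps/vxorps) instead of emitting the mask and shift. It can be fed to llc for an x86-64 target (e.g. llc -mtriple=x86_64-unknown-unknown).

; Hypothetical example, assumed function name and constants.
define <4 x i32> @splat_lshr_known_zero(<4 x i32> %x) {
  ; Mask each element to its low 3 bits, then shift right by a splat of 3:
  ; every demanded bit is known zero, so this now folds to a zero vector.
  %1 = and <4 x i32> %x, <i32 7, i32 7, i32 7, i32 7>
  %2 = lshr <4 x i32> %1, <i32 3, i32 3, i32 3, i32 3>
  ret <4 x i32> %2
}

Before this change, only a scalar ConstantSDNode shift amount was recognized for SRL/SRA, so the splat-vector case kept the and plus shift; by routing SHL, SRL, and SRA through getValidShiftAmountConstant, which uses isConstOrConstSplat, the splat case gets the same known-bits treatment as the scalar case.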