[DAG] enhance computeKnownBits to handle SRL/SRA with vector splat constant

llvm-svn: 284953
Sanjay Patel 2016-10-23 23:13:31 +00:00
parent d11fdad33e
commit 9ca028c2d6
3 changed files with 38 additions and 64 deletions
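Some context on what the patch does: computeKnownBits tracks two masks per value, KnownZero and KnownOne, where a set bit means that bit of the value is definitely 0 (or 1). Before this change the SHL case already accepted a splat-constant vector shift amount, but SRL and SRA insisted on a scalar ConstantSDNode, so splat-vector right shifts were opaque to the analysis; the new helper gives all three opcodes the same splat-aware guard. Below is a minimal standalone sketch of the logical-right-shift transfer rule, using plain C++ masks instead of the APInt/SDValue API (KnownBits32 and knownBitsLshr are illustrative names, not LLVM's):

#include <cassert>
#include <cstdint>

// Illustrative stand-in for the KnownZero/KnownOne pair tracked per value.
struct KnownBits32 { uint32_t Zero, One; };

// Known-bits transfer for a logical right shift: the known masks shift along
// with the value, and the vacated high bits become known zero.
KnownBits32 knownBitsLshr(KnownBits32 In, unsigned ShAmt) {
  assert(ShAmt < 32 && "mirrors getValidShiftAmountConstant's range check");
  KnownBits32 Out = {In.Zero >> ShAmt, In.One >> ShAmt};
  Out.Zero |= ~(~0u >> ShAmt); // high bits vacated by the shift are zero
  return Out;
}

int main() {
  // After (x & 15), the upper 28 bits are known zero, the low 4 are unknown.
  KnownBits32 Masked = {~15u, 0};
  KnownBits32 R = knownBitsLshr(Masked, 4);
  assert(R.Zero == ~0u && "every bit known zero: the value must be 0");
}

The final assert is exactly the fact the test updates below rely on: once every bit of an expression is known zero, the DAG combiner replaces it with a constant zero vector.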

lib/CodeGen/SelectionDAG/SelectionDAG.cpp

@@ -1996,6 +1996,18 @@ bool SelectionDAG::MaskedValueIsZero(SDValue Op, const APInt &Mask,
   return (KnownZero & Mask) == Mask;
 }
 
+/// If a SHL/SRA/SRL node has a constant or splat constant shift amount that
+/// is less than the element bit-width of the shift node, return it.
+static const APInt *getValidShiftAmountConstant(SDValue V) {
+  if (ConstantSDNode *SA = isConstOrConstSplat(V.getOperand(1))) {
+    // Shifting more than the bitwidth is not valid.
+    const APInt &ShAmt = SA->getAPIntValue();
+    if (ShAmt.ult(V.getScalarValueSizeInBits()))
+      return &ShAmt;
+  }
+  return nullptr;
+}
+
 /// Determine which bits of Op are known to be either zero or one and return
 /// them in the KnownZero/KnownOne bitsets.
 void SelectionDAG::computeKnownBits(SDValue Op, APInt &KnownZero,
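One detail worth spelling out (not part of the diff): the bound is checked against the element width via getScalarValueSizeInBits, so a splat shift of a <4 x i32> vector is validated against 32, not against the 128-bit vector width, and an oversized shift yields nullptr so the caller claims nothing. A scalar analogue of the guard, with hypothetical names:

#include <cstdint>
#include <optional>

// validShiftAmount: illustrative stand-in for getValidShiftAmountConstant.
std::optional<uint32_t> validShiftAmount(uint32_t ShAmt, uint32_t EltBits) {
  if (ShAmt < EltBits) // mirrors ShAmt.ult(V.getScalarValueSizeInBits())
    return ShAmt;      // safe: later getZExtValue()/shift calls stay in range
  return std::nullopt; // oversized shift: report no known bits at all
}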
@@ -2144,57 +2156,34 @@ void SelectionDAG::computeKnownBits(SDValue Op, APInt &KnownZero,
     KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - 1);
     break;
   case ISD::SHL:
-    if (ConstantSDNode *SA = isConstOrConstSplat(Op.getOperand(1))) {
-      // If the shift count is an invalid immediate, don't do anything.
-      APInt ShAmt = SA->getAPIntValue();
-      if (ShAmt.uge(BitWidth))
-        break;
+    if (const APInt *ShAmt = getValidShiftAmountConstant(Op)) {
       computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth + 1);
-      KnownZero = KnownZero << ShAmt;
-      KnownOne = KnownOne << ShAmt;
-      // low bits known zero.
-      KnownZero |= APInt::getLowBitsSet(BitWidth, ShAmt.getZExtValue());
+      KnownZero = KnownZero << *ShAmt;
+      KnownOne = KnownOne << *ShAmt;
+      // Low bits are known zero.
+      KnownZero |= APInt::getLowBitsSet(BitWidth, ShAmt->getZExtValue());
     }
     break;
   case ISD::SRL:
-    // FIXME: Reuse isConstOrConstSplat + APInt from above.
-    if (ConstantSDNode *SA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
-      unsigned ShAmt = SA->getZExtValue();
-      // If the shift count is an invalid immediate, don't do anything.
-      if (ShAmt >= BitWidth)
-        break;
-      computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
-      KnownZero = KnownZero.lshr(ShAmt);
-      KnownOne = KnownOne.lshr(ShAmt);
-      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt);
-      KnownZero |= HighBits; // High bits known zero.
+    if (const APInt *ShAmt = getValidShiftAmountConstant(Op)) {
+      computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth + 1);
+      KnownZero = KnownZero.lshr(*ShAmt);
+      KnownOne = KnownOne.lshr(*ShAmt);
+      // High bits are known zero.
+      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt->getZExtValue());
+      KnownZero |= HighBits;
     }
     break;
   case ISD::SRA:
-    // FIXME: Reuse isConstOrConstSplat + APInt from above.
-    if (ConstantSDNode *SA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
-      unsigned ShAmt = SA->getZExtValue();
-      // If the shift count is an invalid immediate, don't do anything.
-      if (ShAmt >= BitWidth)
-        break;
-      // If any of the demanded bits are produced by the sign extension, we also
-      // demand the input sign bit.
-      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt);
-      computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
-      KnownZero = KnownZero.lshr(ShAmt);
-      KnownOne = KnownOne.lshr(ShAmt);
-      // Handle the sign bits.
+    if (const APInt *ShAmt = getValidShiftAmountConstant(Op)) {
+      computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth + 1);
+      KnownZero = KnownZero.lshr(*ShAmt);
+      KnownOne = KnownOne.lshr(*ShAmt);
+      // If we know the value of the sign bit, then we know it is copied across
+      // the high bits by the shift amount.
+      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt->getZExtValue());
       APInt SignBit = APInt::getSignBit(BitWidth);
-      SignBit = SignBit.lshr(ShAmt); // Adjust to where it is now in the mask.
+      SignBit = SignBit.lshr(*ShAmt); // Adjust to where it is now in the mask.
       if (KnownZero.intersects(SignBit)) {
         KnownZero |= HighBits; // New bits are known zero.
       } else if (KnownOne.intersects(SignBit)) {
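The SRA case is the one with a twist: an arithmetic shift copies the sign bit into the vacated positions, so those high bits are known only when the sign bit itself is known. A standalone sketch of that rule over plain 32-bit masks (Known32 and knownBitsAshr are illustrative names mirroring the logic above, not LLVM API):

#include <cassert>
#include <cstdint>

struct Known32 { uint32_t Zero, One; }; // illustrative KnownZero/KnownOne pair

// Known-bits transfer for an arithmetic right shift: the vacated high bits
// take the value of the (pre-shift) sign bit, when that bit is known.
Known32 knownBitsAshr(Known32 In, unsigned ShAmt) {
  Known32 Out = {In.Zero >> ShAmt, In.One >> ShAmt};
  uint32_t HighBits = ~(~0u >> ShAmt);     // positions vacated by the shift
  uint32_t SignBit = 0x80000000u >> ShAmt; // where the old sign bit sits now
  if (Out.Zero & SignBit)
    Out.Zero |= HighBits; // sign known zero: its copies are known zero too
  else if (Out.One & SignBit)
    Out.One |= HighBits;  // sign known one: its copies are known one
  return Out;
}

int main() {
  // Sign bit known one: after ashr by 4, the top five bits are all known one.
  Known32 Neg = {0, 0x80000000u};
  Known32 R = knownBitsAshr(Neg, 4);
  assert(R.One == 0xF8000000u);
}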

test/CodeGen/X86/combine-sra.ll

@@ -300,15 +300,12 @@ define <4 x i32> @combine_vec_ashr_positive(<4 x i32> %x, <4 x i32> %y) {
 define <4 x i32> @combine_vec_ashr_positive_splat(<4 x i32> %x, <4 x i32> %y) {
 ; SSE-LABEL: combine_vec_ashr_positive_splat:
 ; SSE:       # BB#0:
-; SSE-NEXT:    pand {{.*}}(%rip), %xmm0
-; SSE-NEXT:    psrld $10, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_ashr_positive_splat:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpbroadcastd {{.*}}(%rip), %xmm1
-; AVX-NEXT:    vpand %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpsrld $10, %xmm0, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = and <4 x i32> %x, <i32 1023, i32 1023, i32 1023, i32 1023>
   %2 = ashr <4 x i32> %1, <i32 10, i32 10, i32 10, i32 10>
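Why this test now folds to a zero vector: the and clears all but bits [9:0] of each lane, so each lane's sign bit is known zero, and shifting right by 10 then discards every bit that could have been set. A one-lane check in plain C++, just to confirm the arithmetic:

#include <cassert>
#include <cstdint>

int main() {
  // One lane of the vector: (x & 1023) ashr 10, for a few sample inputs.
  const uint32_t samples[] = {0u, 1u, 512u, 1023u, 0xFFFFFFFFu};
  for (uint32_t x : samples) {
    int32_t lane = (int32_t)(x & 1023u) >> 10; // mirrors the and + ashr above
    assert(lane == 0 && "10 unknown low bits, all dropped by the shift");
  }
}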

test/CodeGen/X86/combine-srl.ll

@@ -79,15 +79,12 @@ define <4 x i32> @combine_vec_lshr_by_zero(<4 x i32> %x) {
 define <4 x i32> @combine_vec_lshr_known_zero0(<4 x i32> %x) {
 ; SSE-LABEL: combine_vec_lshr_known_zero0:
 ; SSE:       # BB#0:
-; SSE-NEXT:    pand {{.*}}(%rip), %xmm0
-; SSE-NEXT:    psrld $4, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_lshr_known_zero0:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpbroadcastd {{.*}}(%rip), %xmm1
-; AVX-NEXT:    vpand %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpsrld $4, %xmm0, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = and <4 x i32> %x, <i32 15, i32 15, i32 15, i32 15>
   %2 = lshr <4 x i32> %1, <i32 4, i32 4, i32 4, i32 4>
@@ -292,21 +289,12 @@ define <4 x i32> @combine_vec_lshr_trunc_lshr1(<4 x i64> %x) {
 define <4 x i32> @combine_vec_lshr_trunc_lshr_zero0(<4 x i64> %x) {
 ; SSE-LABEL: combine_vec_lshr_trunc_lshr_zero0:
 ; SSE:       # BB#0:
-; SSE-NEXT:    psrlq $48, %xmm0
-; SSE-NEXT:    psrlq $48, %xmm1
-; SSE-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[0,1,0,2]
-; SSE-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    psrld $24, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_lshr_trunc_lshr_zero0:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpsrlq $48, %ymm0, %ymm0
-; AVX-NEXT:    vpshufd {{.*#+}} ymm0 = ymm0[0,2,2,3,4,6,6,7]
-; AVX-NEXT:    vpermq {{.*#+}} ymm0 = ymm0[0,2,2,3]
-; AVX-NEXT:    vpsrld $24, %xmm0, %xmm0
-; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = lshr <4 x i64> %x, <i64 48, i64 48, i64 48, i64 48>
   %2 = trunc <4 x i64> %1 to <4 x i32>
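This case additionally shows the known bits surviving a truncate: after the lshr by 48 only bits [15:0] of each i64 lane can be nonzero, the trunc to i32 preserves that fact, and the final lshr by 24 (the psrld $24 in the old output) clears everything that remains. One lane in plain C++, again just confirming the arithmetic:

#include <cassert>
#include <cstdint>

int main() {
  // One lane: lshr i64 by 48, trunc to i32, lshr i32 by 24, worst-case input.
  uint64_t x = ~0ull;               // all 64 bits set
  uint32_t t = (uint32_t)(x >> 48); // only bits [15:0] can survive the trunc
  assert((t >> 24) == 0 && "16 possibly-set bits, all shifted out by 24");
}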