diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 16bea7b9fbc5..6d35b1200e06 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4254,6 +4254,29 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
     return DAG.getNode(ISD::ZERO_EXTEND, N->getDebugLoc(), VT,
                        N0.getOperand(0));
 
+  // fold (zext (truncate x)) -> (zext x) or
+  //      (zext (truncate x)) -> (truncate x)
+  // This is valid when the truncated bits of x are already zero.
+  // FIXME: We should extend this to work for vectors too.
+  if (N0.getOpcode() == ISD::TRUNCATE && !VT.isVector()) {
+    SDValue Op = N0.getOperand(0);
+    APInt TruncatedBits
+      = APInt::getBitsSet(Op.getValueSizeInBits(),
+                          N0.getValueSizeInBits(),
+                          std::min(Op.getValueSizeInBits(),
+                                   VT.getSizeInBits()));
+    APInt KnownZero, KnownOne;
+    DAG.ComputeMaskedBits(Op, TruncatedBits, KnownZero, KnownOne);
+    if (TruncatedBits == KnownZero) {
+      if (VT.bitsGT(Op.getValueType()))
+        return DAG.getNode(ISD::ZERO_EXTEND, N->getDebugLoc(), VT, Op);
+      if (VT.bitsLT(Op.getValueType()))
+        return DAG.getNode(ISD::TRUNCATE, N->getDebugLoc(), VT, Op);
+
+      return Op;
+    }
+  }
+
   // fold (zext (truncate (load x))) -> (zext (smaller load x))
   // fold (zext (truncate (srl (load x), c))) -> (zext (small load (x+c/n)))
   if (N0.getOpcode() == ISD::TRUNCATE) {
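For reference, the condition this combine checks can be illustrated with plain integers rather than the SelectionDAG APIs. The sketch below is only an illustration (the 32/16/64-bit widths are an arbitrary choice, not taken from the patch): zext(trunc x) can be rewritten in terms of x once the bits the truncate would drop, up to the width of the final zext, are known to be zero.

#include <cassert>
#include <cstdint>

// Standalone sketch: x plays the role of an i32, the truncate goes to i16, and
// the zext widens back out to i64. The combine must prove that bits
// [16, min(32, 64)) of x, i.e. bits 16..31, are already zero.
int main() {
  uint32_t x = 0x00001234;                     // high 16 bits happen to be zero
  uint32_t truncated_bits = 0xFFFF0000u;       // bits the truncate would drop

  uint64_t via_trunc = (uint64_t)(uint16_t)x;  // zext (truncate x)
  uint64_t direct = (uint64_t)x;               // the folded form: zext x

  if ((x & truncated_bits) == 0)               // what ComputeMaskedBits proves
    assert(via_trunc == direct);               // the fold preserves the value
  return 0;
}
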
diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index 1cc19b25d5e4..64820215aa45 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -725,6 +725,140 @@ bool X86DAGToDAGISel::MatchAddress(SDValue N, X86ISelAddressMode &AM) {
   return false;
 }
 
+// Implement some heroics to detect shifts of masked values where the mask can
+// be replaced by extending the shift and undoing that in the addressing mode
+// scale. Patterns such as (shl (srl x, c1), c2) are canonicalized into (and
+// (srl x, SHIFT), MASK) by DAGCombines that don't know the shl can be done in
+// the addressing mode. This results in code such as:
+//
+//   int f(short *y, int *lookup_table) {
+//     ...
+//     return *y + lookup_table[*y >> 11];
+//   }
+//
+// Turning into:
+//   movzwl (%rdi), %eax
+//   movl %eax, %ecx
+//   shrl $11, %ecx
+//   addl (%rsi,%rcx,4), %eax
+//
+// Instead of:
+//   movzwl (%rdi), %eax
+//   movl %eax, %ecx
+//   shrl $9, %ecx
+//   andl $124, %rcx
+//   addl (%rsi,%rcx), %eax
+//
+static bool FoldMaskAndShiftToScale(SelectionDAG &DAG, SDValue N,
+                                    X86ISelAddressMode &AM) {
+  // Scale must not be used already.
+  if (AM.IndexReg.getNode() != 0 || AM.Scale != 1) return true;
+
+  SDValue Shift = N;
+  SDValue And = N.getOperand(0);
+  if (N.getOpcode() != ISD::SRL)
+    std::swap(Shift, And);
+  if (Shift.getOpcode() != ISD::SRL || And.getOpcode() != ISD::AND ||
+      !Shift.hasOneUse() ||
+      !isa<ConstantSDNode>(Shift.getOperand(1)) ||
+      !isa<ConstantSDNode>(And.getOperand(1)))
+    return true;
+  SDValue X = (N == Shift ? And.getOperand(0) : Shift.getOperand(0));
+
+  // We only handle up to 64-bit values here as those are what matter for
+  // addressing mode optimizations.
+  if (X.getValueSizeInBits() > 64) return true;
+
+  uint64_t Mask = And.getConstantOperandVal(1);
+  unsigned ShiftAmt = Shift.getConstantOperandVal(1);
+  unsigned MaskLZ = CountLeadingZeros_64(Mask);
+  unsigned MaskTZ = CountTrailingZeros_64(Mask);
+
+  // The amount of shift we're trying to fit into the addressing mode is taken
+  // from the trailing zeros of the mask. If the mask is pre-shift, we subtract
+  // the shift amount.
+  int AMShiftAmt = MaskTZ - (N == Shift ? ShiftAmt : 0);
+
+  // There is nothing we can do here unless the mask is removing some bits.
+  // Also, the addressing mode can only represent shifts of 1, 2, or 3 bits.
+  if (AMShiftAmt <= 0 || AMShiftAmt > 3) return true;
+
+  // We also need to ensure that mask is a continuous run of bits.
+  if (CountTrailingOnes_64(Mask >> MaskTZ) + MaskTZ + MaskLZ != 64) return true;
+
+  // Scale the leading zero count down based on the actual size of the value.
+  // Also scale it down based on the size of the shift if it was applied
+  // before the mask.
+  MaskLZ -= (64 - X.getValueSizeInBits()) + (N == Shift ? 0 : ShiftAmt);
+
+  // The final check is to ensure that any masked out high bits of X are
+  // already known to be zero. Otherwise, the mask has a semantic impact
+  // other than masking out a couple of low bits. Unfortunately, because of
+  // the mask, zero extensions will be removed from operands in some cases.
+  // This code works extra hard to look through extensions because we can
+  // replace them with zero extensions cheaply if necessary.
+  bool ReplacingAnyExtend = false;
+  if (X.getOpcode() == ISD::ANY_EXTEND) {
+    unsigned ExtendBits =
+      X.getValueSizeInBits() - X.getOperand(0).getValueSizeInBits();
+    // Assume that we'll replace the any-extend with a zero-extend, and
+    // narrow the search to the extended value.
+    X = X.getOperand(0);
+    MaskLZ = ExtendBits > MaskLZ ? 0 : MaskLZ - ExtendBits;
+    ReplacingAnyExtend = true;
+  }
+  APInt MaskedHighBits = APInt::getHighBitsSet(X.getValueSizeInBits(),
+                                               MaskLZ);
+  APInt KnownZero, KnownOne;
+  DAG.ComputeMaskedBits(X, MaskedHighBits, KnownZero, KnownOne);
+  if (MaskedHighBits != KnownZero) return true;
+
+  // We've identified a pattern that can be transformed into a single shift
+  // and an addressing mode. Make it so.
+  EVT VT = N.getValueType();
+  if (ReplacingAnyExtend) {
+    assert(X.getValueType() != VT);
+    // We looked through an ANY_EXTEND node, insert a ZERO_EXTEND.
+    SDValue NewX = DAG.getNode(ISD::ZERO_EXTEND, X.getDebugLoc(), VT, X);
+    if (NewX.getNode()->getNodeId() == -1 ||
+        NewX.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+      DAG.RepositionNode(N.getNode(), NewX.getNode());
+      NewX.getNode()->setNodeId(N.getNode()->getNodeId());
+    }
+    X = NewX;
+  }
+  DebugLoc DL = N.getDebugLoc();
+  SDValue NewSRLAmt = DAG.getConstant(ShiftAmt + AMShiftAmt, MVT::i8);
+  SDValue NewSRL = DAG.getNode(ISD::SRL, DL, VT, X, NewSRLAmt);
+  SDValue NewSHLAmt = DAG.getConstant(AMShiftAmt, MVT::i8);
+  SDValue NewSHL = DAG.getNode(ISD::SHL, DL, VT, NewSRL, NewSHLAmt);
+  if (NewSRLAmt.getNode()->getNodeId() == -1 ||
+      NewSRLAmt.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+    DAG.RepositionNode(N.getNode(), NewSRLAmt.getNode());
+    NewSRLAmt.getNode()->setNodeId(N.getNode()->getNodeId());
+  }
+  if (NewSRL.getNode()->getNodeId() == -1 ||
+      NewSRL.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+    DAG.RepositionNode(N.getNode(), NewSRL.getNode());
+    NewSRL.getNode()->setNodeId(N.getNode()->getNodeId());
+  }
+  if (NewSHLAmt.getNode()->getNodeId() == -1 ||
+      NewSHLAmt.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+    DAG.RepositionNode(N.getNode(), NewSHLAmt.getNode());
+    NewSHLAmt.getNode()->setNodeId(N.getNode()->getNodeId());
+  }
+  if (NewSHL.getNode()->getNodeId() == -1 ||
+      NewSHL.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+    DAG.RepositionNode(N.getNode(), NewSHL.getNode());
+    NewSHL.getNode()->setNodeId(N.getNode()->getNodeId());
+  }
+  DAG.ReplaceAllUsesWith(N, NewSHL);
+
+  AM.Scale = 1 << AMShiftAmt;
+  AM.IndexReg = NewSRL;
+  return false;
+}
+
 bool X86DAGToDAGISel::MatchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
                                               unsigned Depth) {
   DebugLoc dl = N.getDebugLoc();
@@ -814,6 +948,13 @@ bool X86DAGToDAGISel::MatchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
     break;
   }
 
+  case ISD::SRL:
+    // Try to fold the mask and shift into the scale, and return false if we
+    // succeed.
+    if (!FoldMaskAndShiftToScale(*CurDAG, N, AM))
+      return false;
+    break;
+
   case ISD::SMUL_LOHI:
   case ISD::UMUL_LOHI:
     // A mul_lohi where we need the low part can be folded as a plain multiply.
@@ -1047,6 +1188,11 @@ bool X86DAGToDAGISel::MatchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
       }
     }
 
+    // Try to fold the mask and shift into the scale, and return false if we
+    // succeed.
+    if (!FoldMaskAndShiftToScale(*CurDAG, N, AM))
+      return false;
+
     // Handle "(X << C1) & C2" as "(X & (C2>>C1)) << C1" if safe and if this
     // allows us to fold the shift into this addressing mode.
    if (Shift.getOpcode() != ISD::SHL) break;
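The addressing-mode rewrite above rests on a simple identity: once the high bits of x are known to be zero, the canonicalized (and (srl x, 9), 124) from the comment computes the same value as (srl x, 11) followed by a multiply by 4, and that trailing multiply is exactly what the scale field of an x86 addressing mode can express. A quick standalone check of the identity for a zero-extended 16-bit value (plain C++, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  // x models a movzwl result: a 16-bit value zero-extended to 32 bits, so bits
  // 16..31 are zero. That is the condition FoldMaskAndShiftToScale verifies
  // with ComputeMaskedBits before it rewrites anything.
  for (uint32_t x = 0; x < 0x10000; ++x) {
    uint32_t masked = (x >> 9) & 124;  // canonical (and (srl x, 9), 124)
    uint32_t scaled = (x >> 11) << 2;  // new srl by 11 plus a scale of 4
    assert(masked == scaled);
  }
  return 0;
}

With the mask gone, the residual shift by AMShiftAmt becomes AM.Scale = 1 << AMShiftAmt (4 in this example) and the widened SRL becomes AM.IndexReg, which is what the two MatchAddressRecursively call sites rely on.
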
diff --git a/llvm/test/CodeGen/X86/fold-and-shift.ll b/llvm/test/CodeGen/X86/fold-and-shift.ll
index c42a421a7c48..93baa0e0eee0 100644
--- a/llvm/test/CodeGen/X86/fold-and-shift.ll
+++ b/llvm/test/CodeGen/X86/fold-and-shift.ll
@@ -31,3 +31,47 @@ entry:
   %tmp9 = load i32* %tmp78
   ret i32 %tmp9
 }
+
+define i32 @t3(i16* %i.ptr, i32* %arr) {
+; This case is tricky. The lshr followed by a gep will produce a lshr followed
+; by an and to remove the low bits. This can be simplified by doing the lshr by
+; a greater constant and using the addressing mode to scale the result back up.
+; To make matters worse, because of the two-phase zext of %i and their reuse in
+; the function, the DAG can get confused trying to re-use both of them and
+; prevent easy analysis of the mask in order to match this.
+; CHECK: t3:
+; CHECK-NOT: and
+; CHECK: shrl
+; CHECK: addl (%{{...}},%{{...}},4),
+; CHECK: ret
+
+entry:
+  %i = load i16* %i.ptr
+  %i.zext = zext i16 %i to i32
+  %index = lshr i32 %i.zext, 11
+  %val.ptr = getelementptr inbounds i32* %arr, i32 %index
+  %val = load i32* %val.ptr
+  %sum = add i32 %val, %i.zext
+  ret i32 %sum
+}
+
+define i32 @t4(i16* %i.ptr, i32* %arr) {
+; A version of @t3 that has more zero extends and more re-use of intermediate
+; values. This exercises slightly different bits of canonicalization.
+; CHECK: t4:
+; CHECK-NOT: and
+; CHECK: shrl
+; CHECK: addl (%{{...}},%{{...}},4),
+; CHECK: ret
+
+entry:
+  %i = load i16* %i.ptr
+  %i.zext = zext i16 %i to i32
+  %index = lshr i32 %i.zext, 11
+  %index.zext = zext i32 %index to i64
+  %val.ptr = getelementptr inbounds i32* %arr, i64 %index.zext
+  %val = load i32* %val.ptr
+  %sum.1 = add i32 %val, %i.zext
+  %sum.2 = add i32 %sum.1, %index
+  ret i32 %sum.2
+}
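
For readers of the test, the two new functions correspond roughly to the C code below. This is a hypothetical reconstruction for readability, not part of the patch: @t3 matches the f() example from the FoldMaskAndShiftToScale comment, and @t4 additionally widens the shifted index to 64 bits for the pointer arithmetic and reuses it in the sum.

#include <cstddef>

// @t3: the shifted, zero-extended short feeds a table lookup; the mask that
// DAGCombine would form around the shift should fold into the scale.
int t3(short *i_ptr, int *arr) {
  int i = (unsigned short)*i_ptr;       // zext i16 -> i32
  return arr[i >> 11] + i;              // lshr + gep + reuse of i
}

// @t4: the shifted index is widened again for the pointer arithmetic and then
// reused in the final sum.
int t4(short *i_ptr, int *arr) {
  int i = (unsigned short)*i_ptr;       // zext i16 -> i32
  int index = i >> 11;                  // lshr i32
  size_t index_zext = (unsigned)index;  // zext i32 -> i64 on LP64 targets
  return arr[index_zext] + i + index;
}

int main() {
  short v = 0x7abc;                     // index = 0x7abc >> 11 = 15
  int table[16] = {0};
  (void)t3(&v, table);
  (void)t4(&v, table);
  return 0;
}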