[InstCombine][X86] simplifyX86varShift - convert variable in-range per-element shift amounts to generic shifts (PR40391)

AVX2/AVX512 per-element shifts can be replaced with generic shifts if the shift amounts are guaranteed to be in-range (upper bits are known zero).
Simon Pilgrim 2020-03-18 11:26:30 +00:00
parent c5b81466c2
commit f4e495a18e
2 changed files with 26 additions and 14 deletions
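For context on the in-range condition described above: in practice the shift amounts are in-range when the source explicitly masks them to the element bit width before the shift. A minimal, hypothetical C illustration of such a pattern (the masked_srav name is illustrative; with Clang, _mm_srav_epi32 normally lowers to the @llvm.x86.avx2.psrav.d intrinsic exercised in the tests below):

#include <immintrin.h>

// Shift amounts are masked to [0,31], so every lane is provably in range
// and the per-element arithmetic shift can become a plain IR 'ashr'.
__m128i masked_srav(__m128i v, __m128i a) {
  __m128i amt = _mm_and_si128(a, _mm_set1_epi32(31));
  return _mm_srav_epi32(v, amt);
}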


@@ -472,17 +472,29 @@ static Value *simplifyX86varShift(const IntrinsicInst &II,
   }
   assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left");
 
-  // Simplify if all shift amounts are constant/undef.
-  auto *CShift = dyn_cast<Constant>(II.getArgOperand(1));
-  if (!CShift)
-    return nullptr;
-
   auto Vec = II.getArgOperand(0);
+  auto Amt = II.getArgOperand(1);
   auto VT = cast<VectorType>(II.getType());
   auto SVT = VT->getVectorElementType();
   int NumElts = VT->getNumElements();
   int BitWidth = SVT->getIntegerBitWidth();
 
+  // If the shift amount is guaranteed to be in-range we can replace it with a
+  // generic shift.
+  APInt UpperBits =
+      APInt::getHighBitsSet(BitWidth, BitWidth - Log2_32(BitWidth));
+  if (llvm::MaskedValueIsZero(Amt, UpperBits,
+                              II.getModule()->getDataLayout())) {
+    return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt)
+                                      : Builder.CreateLShr(Vec, Amt))
+                         : Builder.CreateAShr(Vec, Amt));
+  }
+
+  // Simplify if all shift amounts are constant/undef.
+  auto *CShift = dyn_cast<Constant>(Amt);
+  if (!CShift)
+    return nullptr;
+
   // Collect each element's shift amount.
   // We also collect special cases: UNDEF = -1, OUT-OF-RANGE = BitWidth.
   bool AnyOutOfRange = false;
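A standalone sketch of the arithmetic behind the new check (plain C++, not the LLVM APIs used above; the constants assume 32-bit elements): the mask covers the upper BitWidth - Log2_32(BitWidth) bits, so if those bits are known zero the amount can only use the low log2(BitWidth) bits and is therefore < BitWidth.

#include <cassert>
#include <cstdint>

int main() {
  const unsigned BitWidth = 32;              // element width
  const unsigned Log2BW = 5;                 // Log2_32(32)
  const uint32_t UpperBits = ~0u << Log2BW;  // high 27 bits set (bits 5..31)
  for (uint32_t Amt = 0; Amt < 64; ++Amt) {
    // "Upper bits known zero" is exactly the in-range condition Amt < 32.
    bool InRange = (Amt & UpperBits) == 0;
    assert(InRange == (Amt < BitWidth));
  }
  return 0;
}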


@@ -2681,7 +2681,7 @@ define <32 x i16> @avx512_psllv_w_512_undef(<32 x i16> %v) {
 define <4 x i32> @avx2_psrav_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
 ; CHECK-LABEL: @avx2_psrav_d_128_masked(
 ; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 31, i32 31, i32 31>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <4 x i32> @llvm.x86.avx2.psrav.d(<4 x i32> [[V:%.*]], <4 x i32> [[TMP1]])
+; CHECK-NEXT: [[TMP2:%.*]] = ashr <4 x i32> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT: ret <4 x i32> [[TMP2]]
 ;
   %1 = and <4 x i32> %a, <i32 31, i32 31, i32 31, i32 31>
@@ -2692,7 +2692,7 @@ define <4 x i32> @avx2_psrav_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
 define <8 x i32> @avx2_psrav_d_256_masked(<8 x i32> %v, <8 x i32> %a) {
 ; CHECK-LABEL: @avx2_psrav_d_256_masked(
 ; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i32> [[A:%.*]], <i32 0, i32 1, i32 7, i32 15, i32 16, i32 30, i32 31, i32 31>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <8 x i32> @llvm.x86.avx2.psrav.d.256(<8 x i32> [[V:%.*]], <8 x i32> [[TMP1]])
+; CHECK-NEXT: [[TMP2:%.*]] = ashr <8 x i32> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT: ret <8 x i32> [[TMP2]]
 ;
   %1 = and <8 x i32> %a, <i32 0, i32 1, i32 7, i32 15, i32 16, i32 30, i32 31, i32 31>
@@ -2703,7 +2703,7 @@ define <8 x i32> @avx2_psrav_d_256_masked(<8 x i32> %v, <8 x i32> %a) {
 define <32 x i16> @avx512_psrav_w_512_masked(<32 x i16> %v, <32 x i16> %a) {
 ; CHECK-LABEL: @avx512_psrav_w_512_masked(
 ; CHECK-NEXT: [[TMP1:%.*]] = and <32 x i16> [[A:%.*]], <i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <32 x i16> @llvm.x86.avx512.psrav.w.512(<32 x i16> [[V:%.*]], <32 x i16> [[TMP1]])
+; CHECK-NEXT: [[TMP2:%.*]] = ashr <32 x i16> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT: ret <32 x i16> [[TMP2]]
 ;
   %1 = and <32 x i16> %a, <i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7>
@@ -2714,7 +2714,7 @@ define <32 x i16> @avx512_psrav_w_512_masked(<32 x i16> %v, <32 x i16> %a) {
 define <2 x i64> @avx2_psrlv_q_128_masked(<2 x i64> %v, <2 x i64> %a) {
 ; CHECK-LABEL: @avx2_psrlv_q_128_masked(
 ; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i64> [[A:%.*]], <i64 32, i64 63>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <2 x i64> @llvm.x86.avx2.psrlv.q(<2 x i64> [[V:%.*]], <2 x i64> [[TMP1]])
+; CHECK-NEXT: [[TMP2:%.*]] = lshr <2 x i64> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT: ret <2 x i64> [[TMP2]]
 ;
   %1 = and <2 x i64> %a, <i64 32, i64 63>
@@ -2725,7 +2725,7 @@ define <2 x i64> @avx2_psrlv_q_128_masked(<2 x i64> %v, <2 x i64> %a) {
 define <8 x i32> @avx2_psrlv_d_256_masked(<8 x i32> %v, <8 x i32> %a) {
 ; CHECK-LABEL: @avx2_psrlv_d_256_masked(
 ; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i32> [[A:%.*]], <i32 0, i32 1, i32 7, i32 15, i32 16, i32 30, i32 31, i32 31>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <8 x i32> @llvm.x86.avx2.psrlv.d.256(<8 x i32> [[V:%.*]], <8 x i32> [[TMP1]])
+; CHECK-NEXT: [[TMP2:%.*]] = lshr <8 x i32> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT: ret <8 x i32> [[TMP2]]
 ;
   %1 = and <8 x i32> %a, <i32 0, i32 1, i32 7, i32 15, i32 16, i32 30, i32 31, i32 31>
@@ -2736,7 +2736,7 @@ define <8 x i32> @avx2_psrlv_d_256_masked(<8 x i32> %v, <8 x i32> %a) {
 define <8 x i64> @avx512_psrlv_q_512_masked(<8 x i64> %v, <8 x i64> %a) {
 ; CHECK-LABEL: @avx512_psrlv_q_512_masked(
 ; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i64> [[A:%.*]], <i64 0, i64 1, i64 4, i64 16, i64 32, i64 47, i64 62, i64 63>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <8 x i64> @llvm.x86.avx512.psrlv.q.512(<8 x i64> [[V:%.*]], <8 x i64> [[TMP1]])
+; CHECK-NEXT: [[TMP2:%.*]] = lshr <8 x i64> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT: ret <8 x i64> [[TMP2]]
 ;
   %1 = and <8 x i64> %a, <i64 0, i64 1, i64 4, i64 16, i64 32, i64 47, i64 62, i64 63>
@@ -2747,7 +2747,7 @@ define <8 x i64> @avx512_psrlv_q_512_masked(<8 x i64> %v, <8 x i64> %a) {
 define <4 x i32> @avx2_psllv_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
 ; CHECK-LABEL: @avx2_psllv_d_128_masked(
 ; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 0, i32 15, i32 16, i32 31>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <4 x i32> @llvm.x86.avx2.psllv.d(<4 x i32> [[V:%.*]], <4 x i32> [[TMP1]])
+; CHECK-NEXT: [[TMP2:%.*]] = shl <4 x i32> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT: ret <4 x i32> [[TMP2]]
 ;
   %1 = and <4 x i32> %a, <i32 0, i32 15, i32 16, i32 31>
@@ -2758,7 +2758,7 @@ define <4 x i32> @avx2_psllv_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
 define <4 x i64> @avx2_psllv_q_256_masked(<4 x i64> %v, <4 x i64> %a) {
 ; CHECK-LABEL: @avx2_psllv_q_256_masked(
 ; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i64> [[A:%.*]], <i64 0, i64 16, i64 32, i64 63>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <4 x i64> @llvm.x86.avx2.psllv.q.256(<4 x i64> [[V:%.*]], <4 x i64> [[TMP1]])
+; CHECK-NEXT: [[TMP2:%.*]] = shl <4 x i64> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT: ret <4 x i64> [[TMP2]]
 ;
   %1 = and <4 x i64> %a, <i64 0, i64 16, i64 32, i64 63>
@@ -2769,7 +2769,7 @@ define <4 x i64> @avx2_psllv_q_256_masked(<4 x i64> %v, <4 x i64> %a) {
 define <32 x i16> @avx512_psllv_w_512_masked(<32 x i16> %v, <32 x i16> %a) {
 ; CHECK-LABEL: @avx512_psllv_w_512_masked(
 ; CHECK-NEXT: [[TMP1:%.*]] = and <32 x i16> [[A:%.*]], <i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <32 x i16> @llvm.x86.avx512.psllv.w.512(<32 x i16> [[V:%.*]], <32 x i16> [[TMP1]])
+; CHECK-NEXT: [[TMP2:%.*]] = shl <32 x i16> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT: ret <32 x i16> [[TMP2]]
 ;
   %1 = and <32 x i16> %a, <i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7>