From 5c91aa660386ea22e0d38eb0de4c26d62073ccb3 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Tue, 5 May 2020 12:29:57 +0100 Subject: [PATCH] [InstCombine] Fold or(zext(bswap(x)),shl(zext(bswap(y)),bw/2)) -> bswap(or(zext(x),shl(zext(y), bw/2)) This adds a general combine that can be used to fold: or(zext(OP(x)), shl(zext(OP(y)),bw/2)) --> OP(or(zext(x), shl(zext(y),bw/2))) Allowing us to widen 'concat-able' style or+zext patterns - I've just set this up for BSWAP but we could use this for other similar ops (BITREVERSE for instance). We already do something similar for bitop(bswap(x),bswap(y)) --> bswap(bitop(x,y)) Fixes PR45715 Reviewed By: @lebedev.ri Differential Revision: https://reviews.llvm.org/D79041 --- .../InstCombine/InstCombineAndOrXor.cpp | 46 +++++++++++++++++++ llvm/test/Transforms/InstCombine/or-concat.ll | 38 +++++---------- 2 files changed, 58 insertions(+), 26 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 6cc6dcdd748a..a4d86d751c2f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -2132,6 +2132,49 @@ static Instruction *matchRotate(Instruction &Or) { return IntrinsicInst::Create(F, { ShVal, ShVal, ShAmt }); } +/// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns. +static Instruction *matchOrConcat(Instruction &Or, + InstCombiner::BuilderTy &Builder) { + assert(Or.getOpcode() == Instruction::Or && "bswap requires an 'or'"); + Value *Op0 = Or.getOperand(0), *Op1 = Or.getOperand(1); + Type *Ty = Or.getType(); + + unsigned Width = Ty->getScalarSizeInBits(); + if ((Width & 1) != 0) + return nullptr; + unsigned HalfWidth = Width / 2; + + // Canonicalize zext (lower half) to LHS. + if (!isa(Op0)) + std::swap(Op0, Op1); + + // Find lower/upper half. + Value *LowerSrc, *ShlVal, *UpperSrc; + const APInt *C; + if (!match(Op0, m_OneUse(m_ZExt(m_Value(LowerSrc)))) || + !match(Op1, m_OneUse(m_Shl(m_Value(ShlVal), m_APInt(C)))) || + !match(ShlVal, m_OneUse(m_ZExt(m_Value(UpperSrc))))) + return nullptr; + if (*C != HalfWidth || LowerSrc->getType() != UpperSrc->getType() || + LowerSrc->getType()->getScalarSizeInBits() != HalfWidth) + return nullptr; + + // Find matching bswap instructions. + // TODO: Add more patterns (bitreverse?) + Value *LowerBSwap, *UpperBSwap; + if (!match(LowerSrc, m_BSwap(m_Value(LowerBSwap))) || + !match(UpperSrc, m_BSwap(m_Value(UpperBSwap)))) + return nullptr; + + // Push the concat down, swapping the lower/upper sources. + Value *NewLower = Builder.CreateZExt(UpperBSwap, Ty); + Value *NewUpper = Builder.CreateZExt(LowerBSwap, Ty); + NewUpper = Builder.CreateShl(NewUpper, HalfWidth); + Value *BinOp = Builder.CreateOr(NewLower, NewUpper); + Function *F = Intrinsic::getDeclaration(Or.getModule(), Intrinsic::bswap, Ty); + return Builder.CreateCall(F, BinOp); +} + /// If all elements of two constant vectors are 0/-1 and inverses, return true. static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { unsigned NumElts = cast(C1->getType())->getNumElements(); @@ -2532,6 +2575,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Instruction *Rotate = matchRotate(I)) return Rotate; + if (Instruction *Concat = matchOrConcat(I, Builder)) + return replaceInstUsesWith(I, Concat); + Value *X, *Y; const APInt *CV; if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) && diff --git a/llvm/test/Transforms/InstCombine/or-concat.ll b/llvm/test/Transforms/InstCombine/or-concat.ll index f0d36f2a60e5..77cdaa9a37dd 100644 --- a/llvm/test/Transforms/InstCombine/or-concat.ll +++ b/llvm/test/Transforms/InstCombine/or-concat.ll @@ -13,16 +13,8 @@ ; PR45715 define i64 @concat_bswap32_unary_split(i64 %a0) { ; CHECK-LABEL: @concat_bswap32_unary_split( -; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[A0:%.*]], 32 -; CHECK-NEXT: [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32 -; CHECK-NEXT: [[TMP3:%.*]] = trunc i64 [[A0]] to i32 -; CHECK-NEXT: [[TMP4:%.*]] = tail call i32 @llvm.bswap.i32(i32 [[TMP2]]) -; CHECK-NEXT: [[TMP5:%.*]] = tail call i32 @llvm.bswap.i32(i32 [[TMP3]]) -; CHECK-NEXT: [[TMP6:%.*]] = zext i32 [[TMP4]] to i64 -; CHECK-NEXT: [[TMP7:%.*]] = zext i32 [[TMP5]] to i64 -; CHECK-NEXT: [[TMP8:%.*]] = shl nuw i64 [[TMP7]], 32 -; CHECK-NEXT: [[TMP9:%.*]] = or i64 [[TMP8]], [[TMP6]] -; CHECK-NEXT: ret i64 [[TMP9]] +; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.bswap.i64(i64 [[A0:%.*]]) +; CHECK-NEXT: ret i64 [[TMP1]] ; %1 = lshr i64 %a0, 32 %2 = trunc i64 %1 to i32 @@ -39,15 +31,10 @@ define i64 @concat_bswap32_unary_split(i64 %a0) { define i64 @concat_bswap32_unary_flip(i64 %a0) { ; CHECK-LABEL: @concat_bswap32_unary_flip( ; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[A0:%.*]], 32 -; CHECK-NEXT: [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32 -; CHECK-NEXT: [[TMP3:%.*]] = trunc i64 [[A0]] to i32 -; CHECK-NEXT: [[TMP4:%.*]] = tail call i32 @llvm.bswap.i32(i32 [[TMP2]]) -; CHECK-NEXT: [[TMP5:%.*]] = tail call i32 @llvm.bswap.i32(i32 [[TMP3]]) -; CHECK-NEXT: [[TMP6:%.*]] = zext i32 [[TMP4]] to i64 -; CHECK-NEXT: [[TMP7:%.*]] = zext i32 [[TMP5]] to i64 -; CHECK-NEXT: [[TMP8:%.*]] = shl nuw i64 [[TMP6]], 32 -; CHECK-NEXT: [[TMP9:%.*]] = or i64 [[TMP8]], [[TMP7]] -; CHECK-NEXT: ret i64 [[TMP9]] +; CHECK-NEXT: [[TMP2:%.*]] = shl i64 [[A0]], 32 +; CHECK-NEXT: [[TMP3:%.*]] = or i64 [[TMP1]], [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = call i64 @llvm.bswap.i64(i64 [[TMP3]]) +; CHECK-NEXT: ret i64 [[TMP4]] ; %1 = lshr i64 %a0, 32 %2 = trunc i64 %1 to i32 @@ -63,13 +50,12 @@ define i64 @concat_bswap32_unary_flip(i64 %a0) { define i64 @concat_bswap32_binary(i32 %a0, i32 %a1) { ; CHECK-LABEL: @concat_bswap32_binary( -; CHECK-NEXT: [[TMP1:%.*]] = tail call i32 @llvm.bswap.i32(i32 [[A0:%.*]]) -; CHECK-NEXT: [[TMP2:%.*]] = tail call i32 @llvm.bswap.i32(i32 [[A1:%.*]]) -; CHECK-NEXT: [[TMP3:%.*]] = zext i32 [[TMP1]] to i64 -; CHECK-NEXT: [[TMP4:%.*]] = zext i32 [[TMP2]] to i64 -; CHECK-NEXT: [[TMP5:%.*]] = shl nuw i64 [[TMP4]], 32 -; CHECK-NEXT: [[TMP6:%.*]] = or i64 [[TMP5]], [[TMP3]] -; CHECK-NEXT: ret i64 [[TMP6]] +; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[A1:%.*]] to i64 +; CHECK-NEXT: [[TMP2:%.*]] = zext i32 [[A0:%.*]] to i64 +; CHECK-NEXT: [[TMP3:%.*]] = shl nuw i64 [[TMP2]], 32 +; CHECK-NEXT: [[TMP4:%.*]] = or i64 [[TMP3]], [[TMP1]] +; CHECK-NEXT: [[TMP5:%.*]] = call i64 @llvm.bswap.i64(i64 [[TMP4]]) +; CHECK-NEXT: ret i64 [[TMP5]] ; %1 = tail call i32 @llvm.bswap.i32(i32 %a0) %2 = tail call i32 @llvm.bswap.i32(i32 %a1)