diff --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h index e71ac273c080..931576651224 100644 --- a/llvm/include/llvm/IR/Constant.h +++ b/llvm/include/llvm/IR/Constant.h @@ -90,6 +90,10 @@ public: /// elements. bool containsUndefElement() const; + /// Return true if this is a vector constant that includes any constant + /// expressions. + bool containsConstantExpression() const; + /// Return true if evaluation of this constant could trap. This is true for /// things like constant expressions that could divide by zero. bool canTrap() const; diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 047f84f3ef0d..16c0b51a380c 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -4917,7 +4917,6 @@ static Value *simplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, const APInt *ShAmtC; if (match(ShAmtArg, m_APInt(ShAmtC))) { // If there's effectively no shift, return the 1st arg or 2nd arg. - // TODO: For vectors, we could check each element of a non-splat constant. APInt BitWidth = APInt(ShAmtC->getBitWidth(), ShAmtC->getBitWidth()); if (ShAmtC->urem(BitWidth).isNullValue()) return ArgBegin[IID == Intrinsic::fshl ? 0 : 1]; diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index 00d6cc7ea23f..a1619921f8b4 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -260,6 +260,16 @@ bool Constant::containsUndefElement() const { return false; } +bool Constant::containsConstantExpression() const { + if (!getType()->isVectorTy()) + return false; + for (unsigned i = 0, e = getType()->getVectorNumElements(); i != e; ++i) + if (isa<ConstantExpr>(getAggregateElement(i))) + return true; + + return false; +} + /// Constructor to create a '0' constant of arbitrary type. 
Constant *Constant::getNullValue(Type *Ty) { switch (Ty->getTypeID()) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index dc8fd4766d49..3c8898004bb7 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -1994,10 +1994,22 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::fshl: case Intrinsic::fshr: { + // Canonicalize a shift amount constant operand to be modulo the bit-width. + unsigned BitWidth = II->getType()->getScalarSizeInBits(); + Constant *ShAmtC; + if (match(II->getArgOperand(2), m_Constant(ShAmtC)) && + !isa<ConstantExpr>(ShAmtC) && !ShAmtC->containsConstantExpression()) { + Constant *WidthC = ConstantInt::get(II->getType(), BitWidth); + Constant *ModuloC = ConstantExpr::getURem(ShAmtC, WidthC); + if (ModuloC != ShAmtC) { + II->setArgOperand(2, ModuloC); + return II; + } + } + const APInt *SA; if (match(II->getArgOperand(2), m_APInt(SA))) { Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1); - unsigned BitWidth = SA->getBitWidth(); uint64_t ShiftAmt = SA->urem(BitWidth); assert(ShiftAmt != 0 && "SimplifyCall should have handled zero shift"); // Normalize to funnel shift left. @@ -2020,7 +2032,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // The shift amount (operand 2) of a funnel shift is modulo the bitwidth, // so only the low bits of the shift amount are demanded if the bitwidth is // a power-of-2. 
- unsigned BitWidth = II->getType()->getScalarSizeInBits(); if (!isPowerOf2_32(BitWidth)) break; APInt Op2Demanded = APInt::getLowBitsSet(BitWidth, Log2_32_Ceil(BitWidth)); diff --git a/llvm/test/Transforms/InstCombine/fsh.ll b/llvm/test/Transforms/InstCombine/fsh.ll index eaa699cd3426..b913a3d614ba 100644 --- a/llvm/test/Transforms/InstCombine/fsh.ll +++ b/llvm/test/Transforms/InstCombine/fsh.ll @@ -310,7 +310,7 @@ define <2 x i31> @fshl_only_op1_demanded_vec_splat(<2 x i31> %x, <2 x i31> %y) { define i32 @fshl_constant_shift_amount_modulo_bitwidth(i32 %x, i32 %y) { ; CHECK-LABEL: @fshl_constant_shift_amount_modulo_bitwidth( -; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.fshl.i32(i32 [[X:%.*]], i32 [[Y:%.*]], i32 33) +; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.fshl.i32(i32 [[X:%.*]], i32 [[Y:%.*]], i32 1) ; CHECK-NEXT: ret i32 [[R]] ; %r = call i32 @llvm.fshl.i32(i32 %x, i32 %y, i32 33) @@ -319,16 +319,28 @@ define i32 @fshl_constant_shift_amount_modulo_bitwidth(i32 %x, i32 %y) { define i33 @fshr_constant_shift_amount_modulo_bitwidth(i33 %x, i33 %y) { ; CHECK-LABEL: @fshr_constant_shift_amount_modulo_bitwidth( -; CHECK-NEXT: [[R:%.*]] = call i33 @llvm.fshr.i33(i33 [[X:%.*]], i33 [[Y:%.*]], i33 34) +; CHECK-NEXT: [[R:%.*]] = call i33 @llvm.fshr.i33(i33 [[X:%.*]], i33 [[Y:%.*]], i33 1) ; CHECK-NEXT: ret i33 [[R]] ; %r = call i33 @llvm.fshr.i33(i33 %x, i33 %y, i33 34) ret i33 %r } +@external_global = external global i8 + +define i33 @fshr_constant_shift_amount_modulo_bitwidth_constexpr(i33 %x, i33 %y) { +; CHECK-LABEL: @fshr_constant_shift_amount_modulo_bitwidth_constexpr( +; CHECK-NEXT: [[R:%.*]] = call i33 @llvm.fshr.i33(i33 [[X:%.*]], i33 [[Y:%.*]], i33 ptrtoint (i8* @external_global to i33)) +; CHECK-NEXT: ret i33 [[R]] +; + %shamt = ptrtoint i8* @external_global to i33 + %r = call i33 @llvm.fshr.i33(i33 %x, i33 %y, i33 %shamt) + ret i33 %r +} + define <2 x i32> @fshr_constant_shift_amount_modulo_bitwidth_vec(<2 x i32> %x, <2 x i32> %y) { ; CHECK-LABEL: 
@fshr_constant_shift_amount_modulo_bitwidth_vec( -; CHECK-NEXT: [[R:%.*]] = call <2 x i32> @llvm.fshr.v2i32(<2 x i32> [[X:%.*]], <2 x i32> [[Y:%.*]], <2 x i32> ) +; CHECK-NEXT: [[R:%.*]] = call <2 x i32> @llvm.fshr.v2i32(<2 x i32> [[X:%.*]], <2 x i32> [[Y:%.*]], <2 x i32> ) ; CHECK-NEXT: ret <2 x i32> [[R]] ; %r = call <2 x i32> @llvm.fshr.v2i32(<2 x i32> %x, <2 x i32> %y, <2 x i32> ) @@ -373,17 +385,28 @@ define <2 x i32> @fshr_constant_shift_amount_modulo_bitwidth_vec(<2 x i32> %x, < define <2 x i31> @fshl_constant_shift_amount_modulo_bitwidth_vec(<2 x i31> %x, <2 x i31> %y) { ; CHECK-LABEL: @fshl_constant_shift_amount_modulo_bitwidth_vec( -; CHECK-NEXT: [[R:%.*]] = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> [[X:%.*]], <2 x i31> [[Y:%.*]], <2 x i31> ) +; CHECK-NEXT: [[R:%.*]] = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> [[X:%.*]], <2 x i31> [[Y:%.*]], <2 x i31> ) ; CHECK-NEXT: ret <2 x i31> [[R]] ; %r = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> %x, <2 x i31> %y, <2 x i31> ) ret <2 x i31> %r } -; The shift modulo bitwidth is the same for all vector elements, but this is not simplified yet. +define <2 x i31> @fshl_constant_shift_amount_modulo_bitwidth_vec_const_expr(<2 x i31> %x, <2 x i31> %y) { +; CHECK-LABEL: @fshl_constant_shift_amount_modulo_bitwidth_vec_const_expr( +; CHECK-NEXT: [[R:%.*]] = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> [[X:%.*]], <2 x i31> [[Y:%.*]], <2 x i31> ) +; CHECK-NEXT: ret <2 x i31> [[R]] +; + %shamt = ptrtoint i8* @external_global to i31 + %r = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> %x, <2 x i31> %y, <2 x i31> ) + ret <2 x i31> %r +} + +; The shift modulo bitwidth is the same for all vector elements. 
+ define <2 x i31> @fshl_only_op1_demanded_vec_nonsplat(<2 x i31> %x, <2 x i31> %y) { ; CHECK-LABEL: @fshl_only_op1_demanded_vec_nonsplat( -; CHECK-NEXT: [[Z:%.*]] = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> [[X:%.*]], <2 x i31> [[Y:%.*]], <2 x i31> ) +; CHECK-NEXT: [[Z:%.*]] = lshr <2 x i31> [[Y:%.*]], ; CHECK-NEXT: [[R:%.*]] = and <2 x i31> [[Z]], ; CHECK-NEXT: ret <2 x i31> [[R]] ;