[InstCombine] canonicalize funnel shift constant shift amount to be modulo bitwidth

The shift argument is defined to be modulo the bitwidth, so if that argument
is a constant, we can always reduce the constant to its minimal form to allow
better CSE and other follow-on transforms.

We need to be careful to ignore constant expressions here, or we will likely
infinite loop. I'm adding a general vector constant query for that case.

Differential Revision: https://reviews.llvm.org/D59374

llvm-svn: 356192
This commit is contained in:
Sanjay Patel 2019-03-14 19:22:08 +00:00
parent 6e86216531
commit de1d5d3675
5 changed files with 56 additions and 9 deletions

View File

@ -90,6 +90,10 @@ public:
/// elements. /// elements.
bool containsUndefElement() const; bool containsUndefElement() const;
/// Return true if this is a vector constant that includes any constant
/// expressions.
bool containsConstantExpression() const;
/// Return true if evaluation of this constant could trap. This is true for /// Return true if evaluation of this constant could trap. This is true for
/// things like constant expressions that could divide by zero. /// things like constant expressions that could divide by zero.
bool canTrap() const; bool canTrap() const;

View File

@ -4917,7 +4917,6 @@ static Value *simplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd,
const APInt *ShAmtC; const APInt *ShAmtC;
if (match(ShAmtArg, m_APInt(ShAmtC))) { if (match(ShAmtArg, m_APInt(ShAmtC))) {
// If there's effectively no shift, return the 1st arg or 2nd arg. // If there's effectively no shift, return the 1st arg or 2nd arg.
// TODO: For vectors, we could check each element of a non-splat constant.
APInt BitWidth = APInt(ShAmtC->getBitWidth(), ShAmtC->getBitWidth()); APInt BitWidth = APInt(ShAmtC->getBitWidth(), ShAmtC->getBitWidth());
if (ShAmtC->urem(BitWidth).isNullValue()) if (ShAmtC->urem(BitWidth).isNullValue())
return ArgBegin[IID == Intrinsic::fshl ? 0 : 1]; return ArgBegin[IID == Intrinsic::fshl ? 0 : 1];

View File

@ -260,6 +260,16 @@ bool Constant::containsUndefElement() const {
return false; return false;
} }
bool Constant::containsConstantExpression() const {
if (!getType()->isVectorTy())
return false;
for (unsigned i = 0, e = getType()->getVectorNumElements(); i != e; ++i)
if (isa<ConstantExpr>(getAggregateElement(i)))
return true;
return false;
}
/// Constructor to create a '0' constant of arbitrary type. /// Constructor to create a '0' constant of arbitrary type.
Constant *Constant::getNullValue(Type *Ty) { Constant *Constant::getNullValue(Type *Ty) {
switch (Ty->getTypeID()) { switch (Ty->getTypeID()) {

View File

@ -1994,10 +1994,22 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::fshl: case Intrinsic::fshl:
case Intrinsic::fshr: { case Intrinsic::fshr: {
// Canonicalize a shift amount constant operand to be modulo the bit-width.
unsigned BitWidth = II->getType()->getScalarSizeInBits();
Constant *ShAmtC;
if (match(II->getArgOperand(2), m_Constant(ShAmtC)) &&
!isa<ConstantExpr>(ShAmtC) && !ShAmtC->containsConstantExpression()) {
Constant *WidthC = ConstantInt::get(II->getType(), BitWidth);
Constant *ModuloC = ConstantExpr::getURem(ShAmtC, WidthC);
if (ModuloC != ShAmtC) {
II->setArgOperand(2, ModuloC);
return II;
}
}
const APInt *SA; const APInt *SA;
if (match(II->getArgOperand(2), m_APInt(SA))) { if (match(II->getArgOperand(2), m_APInt(SA))) {
Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1); Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1);
unsigned BitWidth = SA->getBitWidth();
uint64_t ShiftAmt = SA->urem(BitWidth); uint64_t ShiftAmt = SA->urem(BitWidth);
assert(ShiftAmt != 0 && "SimplifyCall should have handled zero shift"); assert(ShiftAmt != 0 && "SimplifyCall should have handled zero shift");
// Normalize to funnel shift left. // Normalize to funnel shift left.
@ -2020,7 +2032,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// The shift amount (operand 2) of a funnel shift is modulo the bitwidth, // The shift amount (operand 2) of a funnel shift is modulo the bitwidth,
// so only the low bits of the shift amount are demanded if the bitwidth is // so only the low bits of the shift amount are demanded if the bitwidth is
// a power-of-2. // a power-of-2.
unsigned BitWidth = II->getType()->getScalarSizeInBits();
if (!isPowerOf2_32(BitWidth)) if (!isPowerOf2_32(BitWidth))
break; break;
APInt Op2Demanded = APInt::getLowBitsSet(BitWidth, Log2_32_Ceil(BitWidth)); APInt Op2Demanded = APInt::getLowBitsSet(BitWidth, Log2_32_Ceil(BitWidth));

View File

@ -310,7 +310,7 @@ define <2 x i31> @fshl_only_op1_demanded_vec_splat(<2 x i31> %x, <2 x i31> %y) {
define i32 @fshl_constant_shift_amount_modulo_bitwidth(i32 %x, i32 %y) { define i32 @fshl_constant_shift_amount_modulo_bitwidth(i32 %x, i32 %y) {
; CHECK-LABEL: @fshl_constant_shift_amount_modulo_bitwidth( ; CHECK-LABEL: @fshl_constant_shift_amount_modulo_bitwidth(
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.fshl.i32(i32 [[X:%.*]], i32 [[Y:%.*]], i32 33) ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.fshl.i32(i32 [[X:%.*]], i32 [[Y:%.*]], i32 1)
; CHECK-NEXT: ret i32 [[R]] ; CHECK-NEXT: ret i32 [[R]]
; ;
%r = call i32 @llvm.fshl.i32(i32 %x, i32 %y, i32 33) %r = call i32 @llvm.fshl.i32(i32 %x, i32 %y, i32 33)
@ -319,16 +319,28 @@ define i32 @fshl_constant_shift_amount_modulo_bitwidth(i32 %x, i32 %y) {
define i33 @fshr_constant_shift_amount_modulo_bitwidth(i33 %x, i33 %y) { define i33 @fshr_constant_shift_amount_modulo_bitwidth(i33 %x, i33 %y) {
; CHECK-LABEL: @fshr_constant_shift_amount_modulo_bitwidth( ; CHECK-LABEL: @fshr_constant_shift_amount_modulo_bitwidth(
; CHECK-NEXT: [[R:%.*]] = call i33 @llvm.fshr.i33(i33 [[X:%.*]], i33 [[Y:%.*]], i33 34) ; CHECK-NEXT: [[R:%.*]] = call i33 @llvm.fshr.i33(i33 [[X:%.*]], i33 [[Y:%.*]], i33 1)
; CHECK-NEXT: ret i33 [[R]] ; CHECK-NEXT: ret i33 [[R]]
; ;
%r = call i33 @llvm.fshr.i33(i33 %x, i33 %y, i33 34) %r = call i33 @llvm.fshr.i33(i33 %x, i33 %y, i33 34)
ret i33 %r ret i33 %r
} }
@external_global = external global i8
define i33 @fshr_constant_shift_amount_modulo_bitwidth_constexpr(i33 %x, i33 %y) {
; CHECK-LABEL: @fshr_constant_shift_amount_modulo_bitwidth_constexpr(
; CHECK-NEXT: [[R:%.*]] = call i33 @llvm.fshr.i33(i33 [[X:%.*]], i33 [[Y:%.*]], i33 ptrtoint (i8* @external_global to i33))
; CHECK-NEXT: ret i33 [[R]]
;
%shamt = ptrtoint i8* @external_global to i33
%r = call i33 @llvm.fshr.i33(i33 %x, i33 %y, i33 %shamt)
ret i33 %r
}
define <2 x i32> @fshr_constant_shift_amount_modulo_bitwidth_vec(<2 x i32> %x, <2 x i32> %y) { define <2 x i32> @fshr_constant_shift_amount_modulo_bitwidth_vec(<2 x i32> %x, <2 x i32> %y) {
; CHECK-LABEL: @fshr_constant_shift_amount_modulo_bitwidth_vec( ; CHECK-LABEL: @fshr_constant_shift_amount_modulo_bitwidth_vec(
; CHECK-NEXT: [[R:%.*]] = call <2 x i32> @llvm.fshr.v2i32(<2 x i32> [[X:%.*]], <2 x i32> [[Y:%.*]], <2 x i32> <i32 34, i32 -1>) ; CHECK-NEXT: [[R:%.*]] = call <2 x i32> @llvm.fshr.v2i32(<2 x i32> [[X:%.*]], <2 x i32> [[Y:%.*]], <2 x i32> <i32 2, i32 31>)
; CHECK-NEXT: ret <2 x i32> [[R]] ; CHECK-NEXT: ret <2 x i32> [[R]]
; ;
%r = call <2 x i32> @llvm.fshr.v2i32(<2 x i32> %x, <2 x i32> %y, <2 x i32> <i32 34, i32 -1>) %r = call <2 x i32> @llvm.fshr.v2i32(<2 x i32> %x, <2 x i32> %y, <2 x i32> <i32 34, i32 -1>)
@ -373,17 +385,28 @@ define <2 x i32> @fshr_constant_shift_amount_modulo_bitwidth_vec(<2 x i32> %x, <
define <2 x i31> @fshl_constant_shift_amount_modulo_bitwidth_vec(<2 x i31> %x, <2 x i31> %y) { define <2 x i31> @fshl_constant_shift_amount_modulo_bitwidth_vec(<2 x i31> %x, <2 x i31> %y) {
; CHECK-LABEL: @fshl_constant_shift_amount_modulo_bitwidth_vec( ; CHECK-LABEL: @fshl_constant_shift_amount_modulo_bitwidth_vec(
; CHECK-NEXT: [[R:%.*]] = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> [[X:%.*]], <2 x i31> [[Y:%.*]], <2 x i31> <i31 34, i31 -1>) ; CHECK-NEXT: [[R:%.*]] = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> [[X:%.*]], <2 x i31> [[Y:%.*]], <2 x i31> <i31 3, i31 1>)
; CHECK-NEXT: ret <2 x i31> [[R]] ; CHECK-NEXT: ret <2 x i31> [[R]]
; ;
%r = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> %x, <2 x i31> %y, <2 x i31> <i31 34, i31 -1>) %r = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> %x, <2 x i31> %y, <2 x i31> <i31 34, i31 -1>)
ret <2 x i31> %r ret <2 x i31> %r
} }
; The shift modulo bitwidth is the same for all vector elements, but this is not simplified yet. define <2 x i31> @fshl_constant_shift_amount_modulo_bitwidth_vec_const_expr(<2 x i31> %x, <2 x i31> %y) {
; CHECK-LABEL: @fshl_constant_shift_amount_modulo_bitwidth_vec_const_expr(
; CHECK-NEXT: [[R:%.*]] = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> [[X:%.*]], <2 x i31> [[Y:%.*]], <2 x i31> <i31 34, i31 ptrtoint (i8* @external_global to i31)>)
; CHECK-NEXT: ret <2 x i31> [[R]]
;
%shamt = ptrtoint i8* @external_global to i31
%r = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> %x, <2 x i31> %y, <2 x i31> <i31 34, i31 ptrtoint (i8* @external_global to i31)>)
ret <2 x i31> %r
}
; The shift modulo bitwidth is the same for all vector elements.
define <2 x i31> @fshl_only_op1_demanded_vec_nonsplat(<2 x i31> %x, <2 x i31> %y) { define <2 x i31> @fshl_only_op1_demanded_vec_nonsplat(<2 x i31> %x, <2 x i31> %y) {
; CHECK-LABEL: @fshl_only_op1_demanded_vec_nonsplat( ; CHECK-LABEL: @fshl_only_op1_demanded_vec_nonsplat(
; CHECK-NEXT: [[Z:%.*]] = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> [[X:%.*]], <2 x i31> [[Y:%.*]], <2 x i31> <i31 7, i31 38>) ; CHECK-NEXT: [[Z:%.*]] = lshr <2 x i31> [[Y:%.*]], <i31 24, i31 24>
; CHECK-NEXT: [[R:%.*]] = and <2 x i31> [[Z]], <i31 63, i31 31> ; CHECK-NEXT: [[R:%.*]] = and <2 x i31> [[Z]], <i31 63, i31 31>
; CHECK-NEXT: ret <2 x i31> [[R]] ; CHECK-NEXT: ret <2 x i31> [[R]]
; ;