forked from OSchip/llvm-project
Recognize test for overflow in integer multiplication.
If multiplication involves zero-extended arguments and the result is compared as in the patterns: %mul32 = trunc i64 %mul64 to i32 %zext = zext i32 %mul32 to i64 %overflow = icmp ne i64 %mul64, %zext or %overflow = icmp ugt i64 %mul64 , 0xffffffff then the multiplication may be replaced by call to umul.with.overflow. This change fixes PR4917 and PR4918. Differential Revision: http://llvm-reviews.chandlerc.com/D2814 llvm-svn: 206137
This commit is contained in:
parent
d9963c75da
commit
4bb54d51c8
|
@ -2008,6 +2008,236 @@ static Instruction *ProcessUAddIdiom(Instruction &I, Value *OrigAddV,
|
|||
return ExtractValueInst::Create(Call, 1, "uadd.overflow");
|
||||
}
|
||||
|
||||
/// \brief Recognize and process idiom involving test for multiplication
|
||||
/// overflow.
|
||||
///
|
||||
/// The caller has matched a pattern of the form:
|
||||
/// I = cmp u (mul(zext A, zext B), V
|
||||
/// The function checks if this is a test for overflow and if so replaces
|
||||
/// multiplication with call to 'mul.with.overflow' intrinsic.
|
||||
///
|
||||
/// \param I Compare instruction.
|
||||
/// \param MulVal Result of 'mult' instruction. It is one of the arguments of
|
||||
/// the compare instruction. Must be of integer type.
|
||||
/// \param OtherVal The other argument of compare instruction.
|
||||
/// \returns Instruction which must replace the compare instruction, NULL if no
|
||||
/// replacement required.
|
||||
static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
|
||||
Value *OtherVal, InstCombiner &IC) {
|
||||
assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal);
|
||||
assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal);
|
||||
assert(isa<IntegerType>(MulVal->getType()));
|
||||
Instruction *MulInstr = cast<Instruction>(MulVal);
|
||||
assert(MulInstr->getOpcode() == Instruction::Mul);
|
||||
|
||||
Instruction *LHS = cast<Instruction>(MulInstr->getOperand(0)),
|
||||
*RHS = cast<Instruction>(MulInstr->getOperand(1));
|
||||
assert(LHS->getOpcode() == Instruction::ZExt);
|
||||
assert(RHS->getOpcode() == Instruction::ZExt);
|
||||
Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
|
||||
|
||||
// Calculate type and width of the result produced by mul.with.overflow.
|
||||
Type *TyA = A->getType(), *TyB = B->getType();
|
||||
unsigned WidthA = TyA->getPrimitiveSizeInBits(),
|
||||
WidthB = TyB->getPrimitiveSizeInBits();
|
||||
unsigned MulWidth;
|
||||
Type *MulType;
|
||||
if (WidthB > WidthA) {
|
||||
MulWidth = WidthB;
|
||||
MulType = TyB;
|
||||
} else {
|
||||
MulWidth = WidthA;
|
||||
MulType = TyA;
|
||||
}
|
||||
|
||||
// In order to replace the original mul with a narrower mul.with.overflow,
|
||||
// all uses must ignore upper bits of the product. The number of used low
|
||||
// bits must be not greater than the width of mul.with.overflow.
|
||||
if (MulVal->hasNUsesOrMore(2))
|
||||
for (User *U : MulVal->users()) {
|
||||
if (U == &I)
|
||||
continue;
|
||||
if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
|
||||
// Check if truncation ignores bits above MulWidth.
|
||||
unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
|
||||
if (TruncWidth > MulWidth)
|
||||
return 0;
|
||||
} else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
|
||||
// Check if AND ignores bits above MulWidth.
|
||||
if (BO->getOpcode() != Instruction::And)
|
||||
return 0;
|
||||
if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
|
||||
const APInt &CVal = CI->getValue();
|
||||
if (CVal.getBitWidth() - CVal.countLeadingZeros() > MulWidth)
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
// Other uses prohibit this transformation.
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Recognize patterns
|
||||
switch (I.getPredicate()) {
|
||||
case ICmpInst::ICMP_EQ:
|
||||
case ICmpInst::ICMP_NE:
|
||||
// Recognize pattern:
|
||||
// mulval = mul(zext A, zext B)
|
||||
// cmp eq/neq mulval, zext trunc mulval
|
||||
if (ZExtInst *Zext = dyn_cast<ZExtInst>(OtherVal))
|
||||
if (Zext->hasOneUse()) {
|
||||
Value *ZextArg = Zext->getOperand(0);
|
||||
if (TruncInst *Trunc = dyn_cast<TruncInst>(ZextArg))
|
||||
if (Trunc->getType()->getPrimitiveSizeInBits() == MulWidth)
|
||||
break; //Recognized
|
||||
}
|
||||
|
||||
// Recognize pattern:
|
||||
// mulval = mul(zext A, zext B)
|
||||
// cmp eq/neq mulval, and(mulval, mask), mask selects low MulWidth bits.
|
||||
ConstantInt *CI;
|
||||
Value *ValToMask;
|
||||
if (match(OtherVal, m_And(m_Value(ValToMask), m_ConstantInt(CI)))) {
|
||||
if (ValToMask != MulVal)
|
||||
return 0;
|
||||
const APInt &CVal = CI->getValue() + 1;
|
||||
if (CVal.isPowerOf2()) {
|
||||
unsigned MaskWidth = CVal.logBase2();
|
||||
if (MaskWidth == MulWidth)
|
||||
break; // Recognized
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
|
||||
case ICmpInst::ICMP_UGT:
|
||||
// Recognize pattern:
|
||||
// mulval = mul(zext A, zext B)
|
||||
// cmp ugt mulval, max
|
||||
if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
|
||||
APInt MaxVal = APInt::getMaxValue(MulWidth);
|
||||
MaxVal = MaxVal.zext(CI->getBitWidth());
|
||||
if (MaxVal.eq(CI->getValue()))
|
||||
break; // Recognized
|
||||
}
|
||||
return 0;
|
||||
|
||||
case ICmpInst::ICMP_UGE:
|
||||
// Recognize pattern:
|
||||
// mulval = mul(zext A, zext B)
|
||||
// cmp uge mulval, max+1
|
||||
if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
|
||||
APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
|
||||
if (MaxVal.eq(CI->getValue()))
|
||||
break; // Recognized
|
||||
}
|
||||
return 0;
|
||||
|
||||
case ICmpInst::ICMP_ULE:
|
||||
// Recognize pattern:
|
||||
// mulval = mul(zext A, zext B)
|
||||
// cmp ule mulval, max
|
||||
if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
|
||||
APInt MaxVal = APInt::getMaxValue(MulWidth);
|
||||
MaxVal = MaxVal.zext(CI->getBitWidth());
|
||||
if (MaxVal.eq(CI->getValue()))
|
||||
break; // Recognized
|
||||
}
|
||||
return 0;
|
||||
|
||||
case ICmpInst::ICMP_ULT:
|
||||
// Recognize pattern:
|
||||
// mulval = mul(zext A, zext B)
|
||||
// cmp ule mulval, max + 1
|
||||
if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
|
||||
APInt MaxVal(CI->getBitWidth(), 1ULL << MulWidth);
|
||||
if (MaxVal.eq(CI->getValue()))
|
||||
break; // Recognized
|
||||
}
|
||||
return 0;
|
||||
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
|
||||
InstCombiner::BuilderTy *Builder = IC.Builder;
|
||||
Builder->SetInsertPoint(MulInstr);
|
||||
Module *M = I.getParent()->getParent()->getParent();
|
||||
|
||||
// Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
|
||||
Value *MulA = A, *MulB = B;
|
||||
if (WidthA < MulWidth)
|
||||
MulA = Builder->CreateZExt(A, MulType);
|
||||
if (WidthB < MulWidth)
|
||||
MulB = Builder->CreateZExt(B, MulType);
|
||||
Value *F =
|
||||
Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow, MulType);
|
||||
CallInst *Call = Builder->CreateCall2(F, MulA, MulB, "umul");
|
||||
IC.Worklist.Add(MulInstr);
|
||||
|
||||
// If there are uses of mul result other than the comparison, we know that
|
||||
// they are truncation or binary AND. Change them to use result of
|
||||
// mul.with.overflow and ajust properly mask/size.
|
||||
if (MulVal->hasNUsesOrMore(2)) {
|
||||
Value *Mul = Builder->CreateExtractValue(Call, 0, "umul.value");
|
||||
for (User *U : MulVal->users()) {
|
||||
if (U == &I || U == OtherVal)
|
||||
continue;
|
||||
if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
|
||||
if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
|
||||
IC.ReplaceInstUsesWith(*TI, Mul);
|
||||
else
|
||||
TI->setOperand(0, Mul);
|
||||
} else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
|
||||
assert(BO->getOpcode() == Instruction::And);
|
||||
// Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
|
||||
ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
|
||||
APInt ShortMask = CI->getValue().trunc(MulWidth);
|
||||
Value *ShortAnd = Builder->CreateAnd(Mul, ShortMask);
|
||||
Instruction *Zext =
|
||||
cast<Instruction>(Builder->CreateZExt(ShortAnd, BO->getType()));
|
||||
IC.Worklist.Add(Zext);
|
||||
IC.ReplaceInstUsesWith(*BO, Zext);
|
||||
} else {
|
||||
llvm_unreachable("Unexpected Binary operation");
|
||||
}
|
||||
IC.Worklist.Add(cast<Instruction>(U));
|
||||
}
|
||||
}
|
||||
if (isa<Instruction>(OtherVal))
|
||||
IC.Worklist.Add(cast<Instruction>(OtherVal));
|
||||
|
||||
// The original icmp gets replaced with the overflow value, maybe inverted
|
||||
// depending on predicate.
|
||||
bool Inverse = false;
|
||||
switch (I.getPredicate()) {
|
||||
case ICmpInst::ICMP_NE:
|
||||
break;
|
||||
case ICmpInst::ICMP_EQ:
|
||||
Inverse = true;
|
||||
break;
|
||||
case ICmpInst::ICMP_UGT:
|
||||
case ICmpInst::ICMP_UGE:
|
||||
if (I.getOperand(0) == MulVal)
|
||||
break;
|
||||
Inverse = true;
|
||||
break;
|
||||
case ICmpInst::ICMP_ULT:
|
||||
case ICmpInst::ICMP_ULE:
|
||||
if (I.getOperand(1) == MulVal)
|
||||
break;
|
||||
Inverse = true;
|
||||
break;
|
||||
default:
|
||||
llvm_unreachable("Unexpected predicate");
|
||||
}
|
||||
if (Inverse) {
|
||||
Value *Res = Builder->CreateExtractValue(Call, 1);
|
||||
return BinaryOperator::CreateNot(Res);
|
||||
}
|
||||
|
||||
return ExtractValueInst::Create(Call, 1);
|
||||
}
|
||||
|
||||
// DemandedBitsLHSMask - When performing a comparison against a constant,
|
||||
// it is possible that not all the bits in the LHS are demanded. This helper
|
||||
// method computes the mask that IS demanded.
|
||||
|
@ -2877,6 +3107,16 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
|
|||
(Op0 == A || Op0 == B))
|
||||
if (Instruction *R = ProcessUAddIdiom(I, Op1, *this))
|
||||
return R;
|
||||
|
||||
// (zext a) * (zext b) --> llvm.umul.with.overflow.
|
||||
if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
|
||||
if (Instruction *R = ProcessUMulZExtIdiom(I, Op0, Op1, *this))
|
||||
return R;
|
||||
}
|
||||
if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
|
||||
if (Instruction *R = ProcessUMulZExtIdiom(I, Op1, Op0, *this))
|
||||
return R;
|
||||
}
|
||||
}
|
||||
|
||||
if (I.isEquality()) {
|
||||
|
|
|
@ -0,0 +1,164 @@
|
|||
; RUN: opt -S -instcombine < %s | FileCheck %s
|
||||
|
||||
; return mul(zext x, zext y) > MAX
|
||||
define i32 @pr4917_1(i32 %x, i32 %y) nounwind {
|
||||
; CHECK-LABEL: @pr4917_1(
|
||||
entry:
|
||||
%l = zext i32 %x to i64
|
||||
%r = zext i32 %y to i64
|
||||
; CHECK-NOT: zext i32
|
||||
%mul64 = mul i64 %l, %r
|
||||
; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
|
||||
%overflow = icmp ugt i64 %mul64, 4294967295
|
||||
; CHECK: extractvalue { i32, i1 } [[MUL]], 1
|
||||
%retval = zext i1 %overflow to i32
|
||||
ret i32 %retval
|
||||
}
|
||||
|
||||
; return mul(zext x, zext y) >= MAX+1
|
||||
define i32 @pr4917_1a(i32 %x, i32 %y) nounwind {
|
||||
; CHECK-LABEL: @pr4917_1a(
|
||||
entry:
|
||||
%l = zext i32 %x to i64
|
||||
%r = zext i32 %y to i64
|
||||
; CHECK-NOT: zext i32
|
||||
%mul64 = mul i64 %l, %r
|
||||
; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
|
||||
%overflow = icmp uge i64 %mul64, 4294967296
|
||||
; CHECK: extractvalue { i32, i1 } [[MUL]], 1
|
||||
%retval = zext i1 %overflow to i32
|
||||
ret i32 %retval
|
||||
}
|
||||
|
||||
; mul(zext x, zext y) > MAX
|
||||
; mul(x, y) is used
|
||||
define i32 @pr4917_2(i32 %x, i32 %y) nounwind {
|
||||
; CHECK-LABEL: @pr4917_2(
|
||||
entry:
|
||||
%l = zext i32 %x to i64
|
||||
%r = zext i32 %y to i64
|
||||
; CHECK-NOT: zext i32
|
||||
%mul64 = mul i64 %l, %r
|
||||
; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
|
||||
%overflow = icmp ugt i64 %mul64, 4294967295
|
||||
; CHECK-DAG: [[VAL:%.*]] = extractvalue { i32, i1 } [[MUL]], 0
|
||||
%mul32 = trunc i64 %mul64 to i32
|
||||
; CHECK-DAG: [[OVFL:%.*]] = extractvalue { i32, i1 } [[MUL]], 1
|
||||
%retval = select i1 %overflow, i32 %mul32, i32 111
|
||||
; CHECK: select i1 [[OVFL]], i32 [[VAL]]
|
||||
ret i32 %retval
|
||||
}
|
||||
|
||||
; return mul(zext x, zext y) > MAX
|
||||
; mul is used in non-truncate
|
||||
define i64 @pr4917_3(i32 %x, i32 %y) nounwind {
|
||||
; CHECK-LABEL: @pr4917_3(
|
||||
entry:
|
||||
%l = zext i32 %x to i64
|
||||
%r = zext i32 %y to i64
|
||||
%mul64 = mul i64 %l, %r
|
||||
; CHECK-NOT: umul.with.overflow.i32
|
||||
%overflow = icmp ugt i64 %mul64, 4294967295
|
||||
%retval = select i1 %overflow, i64 %mul64, i64 111
|
||||
ret i64 %retval
|
||||
}
|
||||
|
||||
; return mul(zext x, zext y) <= MAX
|
||||
define i32 @pr4917_4(i32 %x, i32 %y) nounwind {
|
||||
; CHECK-LABEL: @pr4917_4(
|
||||
entry:
|
||||
%l = zext i32 %x to i64
|
||||
%r = zext i32 %y to i64
|
||||
; CHECK-NOT: zext i32
|
||||
%mul64 = mul i64 %l, %r
|
||||
; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
|
||||
%overflow = icmp ule i64 %mul64, 4294967295
|
||||
; CHECK: extractvalue { i32, i1 } [[MUL]], 1
|
||||
; CHECK: xor
|
||||
%retval = zext i1 %overflow to i32
|
||||
ret i32 %retval
|
||||
}
|
||||
|
||||
; return mul(zext x, zext y) < MAX+1
|
||||
define i32 @pr4917_4a(i32 %x, i32 %y) nounwind {
|
||||
; CHECK-LABEL: @pr4917_4a(
|
||||
entry:
|
||||
%l = zext i32 %x to i64
|
||||
%r = zext i32 %y to i64
|
||||
; CHECK-NOT: zext i32
|
||||
%mul64 = mul i64 %l, %r
|
||||
; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
|
||||
%overflow = icmp ult i64 %mul64, 4294967296
|
||||
; CHECK: extractvalue { i32, i1 } [[MUL]], 1
|
||||
; CHECK: xor
|
||||
%retval = zext i1 %overflow to i32
|
||||
ret i32 %retval
|
||||
}
|
||||
|
||||
; operands of mul are of different size
|
||||
define i32 @pr4917_5(i32 %x, i8 %y) nounwind {
|
||||
; CHECK-LABEL: @pr4917_5(
|
||||
entry:
|
||||
%l = zext i32 %x to i64
|
||||
%r = zext i8 %y to i64
|
||||
; CHECK: [[Y:%.*]] = zext i8 %y to i32
|
||||
%mul64 = mul i64 %l, %r
|
||||
%overflow = icmp ugt i64 %mul64, 4294967295
|
||||
%mul32 = trunc i64 %mul64 to i32
|
||||
; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 [[Y]])
|
||||
; CHECK-DAG: [[VAL:%.*]] = extractvalue { i32, i1 } [[MUL]], 0
|
||||
; CHECK-DAG: [[OVFL:%.*]] = extractvalue { i32, i1 } [[MUL]], 1
|
||||
%retval = select i1 %overflow, i32 %mul32, i32 111
|
||||
; CHECK: select i1 [[OVFL]], i32 [[VAL]]
|
||||
ret i32 %retval
|
||||
}
|
||||
|
||||
; mul(zext x, zext y) != zext trunc mul
|
||||
define i32 @pr4918_1(i32 %x, i32 %y) nounwind {
|
||||
; CHECK-LABEL: @pr4918_1(
|
||||
entry:
|
||||
%l = zext i32 %x to i64
|
||||
%r = zext i32 %y to i64
|
||||
%mul64 = mul i64 %l, %r
|
||||
; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
|
||||
%part32 = trunc i64 %mul64 to i32
|
||||
%part64 = zext i32 %part32 to i64
|
||||
%overflow = icmp ne i64 %mul64, %part64
|
||||
; CHECK: [[OVFL:%.*]] = extractvalue { i32, i1 } [[MUL:%.*]], 1
|
||||
%retval = zext i1 %overflow to i32
|
||||
ret i32 %retval
|
||||
}
|
||||
|
||||
; mul(zext x, zext y) == zext trunc mul
|
||||
define i32 @pr4918_2(i32 %x, i32 %y) nounwind {
|
||||
; CHECK-LABEL: @pr4918_2(
|
||||
entry:
|
||||
%l = zext i32 %x to i64
|
||||
%r = zext i32 %y to i64
|
||||
%mul64 = mul i64 %l, %r
|
||||
; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
|
||||
%part32 = trunc i64 %mul64 to i32
|
||||
%part64 = zext i32 %part32 to i64
|
||||
%overflow = icmp eq i64 %mul64, %part64
|
||||
; CHECK: extractvalue { i32, i1 } [[MUL]]
|
||||
%retval = zext i1 %overflow to i32
|
||||
; CHECK: xor
|
||||
ret i32 %retval
|
||||
}
|
||||
|
||||
; zext trunc mul != mul(zext x, zext y)
|
||||
define i32 @pr4918_3(i32 %x, i32 %y) nounwind {
|
||||
; CHECK-LABEL: @pr4918_3(
|
||||
entry:
|
||||
%l = zext i32 %x to i64
|
||||
%r = zext i32 %y to i64
|
||||
%mul64 = mul i64 %l, %r
|
||||
; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
|
||||
%part32 = trunc i64 %mul64 to i32
|
||||
%part64 = zext i32 %part32 to i64
|
||||
%overflow = icmp ne i64 %part64, %mul64
|
||||
; CHECK: extractvalue { i32, i1 } [[MUL]], 1
|
||||
%retval = zext i1 %overflow to i32
|
||||
ret i32 %retval
|
||||
}
|
||||
|
Loading…
Reference in New Issue