[instcombine] Fold overflow check using umulo to comparison

If we have a umul.with.overflow where the multiply result is not used and one of the operands is a constant, we can perform the overflow check cheaper with a comparison then by performing the multiply and extracting the overflow flag.

(Noticed when looking at the conditions SCEV emits for overflow checks.)

Differential Revision: https://reviews.llvm.org/D104665
This commit is contained in:
Philip Reames 2021-06-25 10:24:10 -07:00
parent 9eaf0d120d
commit 2cd23eb243
2 changed files with 33 additions and 16 deletions

View File

@ -3083,13 +3083,36 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
return BinaryOperator::Create(BinOp, LHS, RHS); return BinaryOperator::Create(BinOp, LHS, RHS);
} }
// If the normal result of the add is dead, and the RHS is a constant, assert(*EV.idx_begin() == 1 &&
// we can transform this into a range comparison. "unexpected extract index for overflow inst");
// overflow = uadd a, -4 --> overflow = icmp ugt a, 3
if (WO->getIntrinsicID() == Intrinsic::uadd_with_overflow) // If the normal result of the computation is dead, and the RHS is a
// constant, we can transform this into a range comparison for many cases.
// TODO: We can generalize these for non-constant rhs when the newly
// formed expressions are known to simplify. Constants are merely one
// such case.
// TODO: Handle vector splats.
switch (WO->getIntrinsicID()) {
default:
break;
case Intrinsic::uadd_with_overflow:
// overflow = uadd a, -4 --> overflow = icmp ugt a, 3
if (ConstantInt *CI = dyn_cast<ConstantInt>(WO->getRHS())) if (ConstantInt *CI = dyn_cast<ConstantInt>(WO->getRHS()))
return new ICmpInst(ICmpInst::ICMP_UGT, WO->getLHS(), return new ICmpInst(ICmpInst::ICMP_UGT, WO->getLHS(),
ConstantExpr::getNot(CI)); ConstantExpr::getNot(CI));
break;
case Intrinsic::umul_with_overflow:
// overflow for umul a, C --> a > UINT_MAX udiv C
// (unless C == 0, in which case no overflow ever occurs)
if (ConstantInt *CI = dyn_cast<ConstantInt>(WO->getRHS())) {
assert(!CI->isZero() && "handled by instruction simplify");
auto UMax = APInt::getMaxValue(CI->getType()->getBitWidth());
auto *Op =
ConstantExpr::getUDiv(ConstantInt::get(CI->getType(), UMax), CI);
return new ICmpInst(ICmpInst::ICMP_UGT, WO->getLHS(), Op);
}
break;
};
} }
} }
if (LoadInst *L = dyn_cast<LoadInst>(Agg)) if (LoadInst *L = dyn_cast<LoadInst>(Agg))

View File

@ -35,8 +35,7 @@ define i1 @test_constant1(i8 %a) {
define i1 @test_constant2(i8 %a) { define i1 @test_constant2(i8 %a) {
; CHECK-LABEL: @test_constant2( ; CHECK-LABEL: @test_constant2(
; CHECK-NEXT: [[RES:%.*]] = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[A:%.*]], i8 2) ; CHECK-NEXT: [[OVERFLOW:%.*]] = icmp slt i8 [[A:%.*]], 0
; CHECK-NEXT: [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[RES]], 1
; CHECK-NEXT: ret i1 [[OVERFLOW]] ; CHECK-NEXT: ret i1 [[OVERFLOW]]
; ;
%res = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 2) %res = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 2)
@ -46,8 +45,7 @@ define i1 @test_constant2(i8 %a) {
define i1 @test_constant3(i8 %a) { define i1 @test_constant3(i8 %a) {
; CHECK-LABEL: @test_constant3( ; CHECK-LABEL: @test_constant3(
; CHECK-NEXT: [[RES:%.*]] = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[A:%.*]], i8 3) ; CHECK-NEXT: [[OVERFLOW:%.*]] = icmp ugt i8 [[A:%.*]], 85
; CHECK-NEXT: [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[RES]], 1
; CHECK-NEXT: ret i1 [[OVERFLOW]] ; CHECK-NEXT: ret i1 [[OVERFLOW]]
; ;
%res = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 3) %res = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 3)
@ -57,8 +55,7 @@ define i1 @test_constant3(i8 %a) {
define i1 @test_constant4(i8 %a) { define i1 @test_constant4(i8 %a) {
; CHECK-LABEL: @test_constant4( ; CHECK-LABEL: @test_constant4(
; CHECK-NEXT: [[RES:%.*]] = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[A:%.*]], i8 4) ; CHECK-NEXT: [[OVERFLOW:%.*]] = icmp ugt i8 [[A:%.*]], 63
; CHECK-NEXT: [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[RES]], 1
; CHECK-NEXT: ret i1 [[OVERFLOW]] ; CHECK-NEXT: ret i1 [[OVERFLOW]]
; ;
%res = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 4) %res = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 4)
@ -69,8 +66,7 @@ define i1 @test_constant4(i8 %a) {
define i1 @test_constant127(i8 %a) { define i1 @test_constant127(i8 %a) {
; CHECK-LABEL: @test_constant127( ; CHECK-LABEL: @test_constant127(
; CHECK-NEXT: [[RES:%.*]] = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[A:%.*]], i8 127) ; CHECK-NEXT: [[OVERFLOW:%.*]] = icmp ugt i8 [[A:%.*]], 2
; CHECK-NEXT: [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[RES]], 1
; CHECK-NEXT: ret i1 [[OVERFLOW]] ; CHECK-NEXT: ret i1 [[OVERFLOW]]
; ;
%res = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 127) %res = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 127)
@ -80,8 +76,7 @@ define i1 @test_constant127(i8 %a) {
define i1 @test_constant128(i8 %a) { define i1 @test_constant128(i8 %a) {
; CHECK-LABEL: @test_constant128( ; CHECK-LABEL: @test_constant128(
; CHECK-NEXT: [[RES:%.*]] = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[A:%.*]], i8 -128) ; CHECK-NEXT: [[OVERFLOW:%.*]] = icmp ugt i8 [[A:%.*]], 1
; CHECK-NEXT: [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[RES]], 1
; CHECK-NEXT: ret i1 [[OVERFLOW]] ; CHECK-NEXT: ret i1 [[OVERFLOW]]
; ;
%res = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 128) %res = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 128)
@ -91,8 +86,7 @@ define i1 @test_constant128(i8 %a) {
define i1 @test_constant255(i8 %a) { define i1 @test_constant255(i8 %a) {
; CHECK-LABEL: @test_constant255( ; CHECK-LABEL: @test_constant255(
; CHECK-NEXT: [[RES:%.*]] = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[A:%.*]], i8 -1) ; CHECK-NEXT: [[OVERFLOW:%.*]] = icmp ugt i8 [[A:%.*]], 1
; CHECK-NEXT: [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[RES]], 1
; CHECK-NEXT: ret i1 [[OVERFLOW]] ; CHECK-NEXT: ret i1 [[OVERFLOW]]
; ;
%res = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 255) %res = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 255)