[NFCI] Introduce `ICmpInst::compare()` and use it where appropriate

As noted in https://reviews.llvm.org/D90924#inline-1076197
apparently this is a pretty common pattern,
let's not repeat it yet again, but have it in a common place.

There may be some more places where it could be used,
but these are the most obvious ones.
This commit is contained in:
Roman Lebedev 2021-10-30 17:36:23 +03:00
parent 2c4a9e830c
commit 25043c8276
No known key found for this signature in database
GPG Key ID: 083C3EBB4A1689E0
9 changed files with 74 additions and 110 deletions

View File

@ -104,9 +104,12 @@ ISD::CondCode getFCmpCodeWithoutNaN(ISD::CondCode CC);
/// getICmpCondCode - Return the ISD condition code corresponding to
/// the given LLVM IR integer condition code.
///
ISD::CondCode getICmpCondCode(ICmpInst::Predicate Pred);
/// getICmpCondCode - Return the LLVM IR integer condition code
/// corresponding to the given ISD integer condition code.
ICmpInst::Predicate getICmpCondCode(ISD::CondCode Pred);
/// Test if the given instruction is in a position to be optimized
/// with a tail-call. This roughly means that it's in a block with
/// a return and there's nothing that needs to be scheduled

View File

@ -1349,6 +1349,10 @@ public:
Op<0>().swap(Op<1>());
}
/// Return result of `LHS Pred RHS` comparison.
static bool compare(const APInt &LHS, const APInt &RHS,
ICmpInst::Predicate Pred);
// Methods for support type inquiry through isa, cast, and dyn_cast:
static bool classof(const Instruction *I) {
return I->getOpcode() == Instruction::ICmp;

View File

@ -593,32 +593,7 @@ inline cst_pred_ty<is_lowbit_mask> m_LowBitMask() {
struct icmp_pred_with_threshold {
ICmpInst::Predicate Pred;
const APInt *Thr;
bool isValue(const APInt &C) {
switch (Pred) {
case ICmpInst::Predicate::ICMP_EQ:
return C.eq(*Thr);
case ICmpInst::Predicate::ICMP_NE:
return C.ne(*Thr);
case ICmpInst::Predicate::ICMP_UGT:
return C.ugt(*Thr);
case ICmpInst::Predicate::ICMP_UGE:
return C.uge(*Thr);
case ICmpInst::Predicate::ICMP_ULT:
return C.ult(*Thr);
case ICmpInst::Predicate::ICMP_ULE:
return C.ule(*Thr);
case ICmpInst::Predicate::ICMP_SGT:
return C.sgt(*Thr);
case ICmpInst::Predicate::ICMP_SGE:
return C.sge(*Thr);
case ICmpInst::Predicate::ICMP_SLT:
return C.slt(*Thr);
case ICmpInst::Predicate::ICMP_SLE:
return C.sle(*Thr);
default:
llvm_unreachable("Unhandled ICmp predicate");
}
}
bool isValue(const APInt &C) { return ICmpInst::compare(C, *Thr, Pred); }
};
/// Match an integer or vector with every element comparing 'pred' (eg/ne/...)
/// to Threshold. For vectors, this includes constants with undefined elements.

View File

@ -221,9 +221,6 @@ ISD::CondCode llvm::getFCmpCodeWithoutNaN(ISD::CondCode CC) {
}
}
/// getICmpCondCode - Return the ISD condition code corresponding to
/// the given LLVM IR integer condition code.
///
ISD::CondCode llvm::getICmpCondCode(ICmpInst::Predicate Pred) {
switch (Pred) {
case ICmpInst::ICMP_EQ: return ISD::SETEQ;
@ -241,6 +238,33 @@ ISD::CondCode llvm::getICmpCondCode(ICmpInst::Predicate Pred) {
}
}
ICmpInst::Predicate llvm::getICmpCondCode(ISD::CondCode Pred) {
switch (Pred) {
case ISD::SETEQ:
return ICmpInst::ICMP_EQ;
case ISD::SETNE:
return ICmpInst::ICMP_NE;
case ISD::SETLE:
return ICmpInst::ICMP_SLE;
case ISD::SETULE:
return ICmpInst::ICMP_ULE;
case ISD::SETGE:
return ICmpInst::ICMP_SGE;
case ISD::SETUGE:
return ICmpInst::ICMP_UGE;
case ISD::SETLT:
return ICmpInst::ICMP_SLT;
case ISD::SETULT:
return ICmpInst::ICMP_ULT;
case ISD::SETGT:
return ICmpInst::ICMP_SGT;
case ISD::SETUGT:
return ICmpInst::ICMP_UGT;
default:
llvm_unreachable("Invalid ISD integer condition code!");
}
}
static bool isNoopBitcast(Type *T1, Type *T2,
const TargetLoweringBase& TLI) {
return T1 == T2 || (T1->isPointerTy() && T2->isPointerTy()) ||

View File

@ -28,6 +28,7 @@
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
@ -2312,19 +2313,8 @@ SDValue SelectionDAG::FoldSetCC(EVT VT, SDValue N1, SDValue N2,
if (ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1)) {
const APInt &C1 = N1C->getAPIntValue();
switch (Cond) {
default: llvm_unreachable("Unknown integer setcc!");
case ISD::SETEQ: return getBoolConstant(C1 == C2, dl, VT, OpVT);
case ISD::SETNE: return getBoolConstant(C1 != C2, dl, VT, OpVT);
case ISD::SETULT: return getBoolConstant(C1.ult(C2), dl, VT, OpVT);
case ISD::SETUGT: return getBoolConstant(C1.ugt(C2), dl, VT, OpVT);
case ISD::SETULE: return getBoolConstant(C1.ule(C2), dl, VT, OpVT);
case ISD::SETUGE: return getBoolConstant(C1.uge(C2), dl, VT, OpVT);
case ISD::SETLT: return getBoolConstant(C1.slt(C2), dl, VT, OpVT);
case ISD::SETGT: return getBoolConstant(C1.sgt(C2), dl, VT, OpVT);
case ISD::SETLE: return getBoolConstant(C1.sle(C2), dl, VT, OpVT);
case ISD::SETGE: return getBoolConstant(C1.sge(C2), dl, VT, OpVT);
}
return getBoolConstant(ICmpInst::compare(C1, C2, getICmpCondCode(Cond)),
dl, VT, OpVT);
}
}

View File

@ -1792,19 +1792,8 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
if (isa<ConstantInt>(C1) && isa<ConstantInt>(C2)) {
const APInt &V1 = cast<ConstantInt>(C1)->getValue();
const APInt &V2 = cast<ConstantInt>(C2)->getValue();
switch (pred) {
default: llvm_unreachable("Invalid ICmp Predicate");
case ICmpInst::ICMP_EQ: return ConstantInt::get(ResultTy, V1 == V2);
case ICmpInst::ICMP_NE: return ConstantInt::get(ResultTy, V1 != V2);
case ICmpInst::ICMP_SLT: return ConstantInt::get(ResultTy, V1.slt(V2));
case ICmpInst::ICMP_SGT: return ConstantInt::get(ResultTy, V1.sgt(V2));
case ICmpInst::ICMP_SLE: return ConstantInt::get(ResultTy, V1.sle(V2));
case ICmpInst::ICMP_SGE: return ConstantInt::get(ResultTy, V1.sge(V2));
case ICmpInst::ICMP_ULT: return ConstantInt::get(ResultTy, V1.ult(V2));
case ICmpInst::ICMP_UGT: return ConstantInt::get(ResultTy, V1.ugt(V2));
case ICmpInst::ICMP_ULE: return ConstantInt::get(ResultTy, V1.ule(V2));
case ICmpInst::ICMP_UGE: return ConstantInt::get(ResultTy, V1.uge(V2));
}
return ConstantInt::get(
ResultTy, ICmpInst::compare(V1, V2, (ICmpInst::Predicate)pred));
} else if (isa<ConstantFP>(C1) && isa<ConstantFP>(C2)) {
const APFloat &C1V = cast<ConstantFP>(C1)->getValueAPF();
const APFloat &C2V = cast<ConstantFP>(C2)->getValueAPF();

View File

@ -4055,6 +4055,35 @@ bool CmpInst::isSigned(Predicate predicate) {
}
}
bool ICmpInst::compare(const APInt &LHS, const APInt &RHS,
ICmpInst::Predicate Pred) {
assert(ICmpInst::isIntPredicate(Pred) && "Only for integer predicates!");
switch (Pred) {
case ICmpInst::Predicate::ICMP_EQ:
return LHS.eq(RHS);
case ICmpInst::Predicate::ICMP_NE:
return LHS.ne(RHS);
case ICmpInst::Predicate::ICMP_UGT:
return LHS.ugt(RHS);
case ICmpInst::Predicate::ICMP_UGE:
return LHS.uge(RHS);
case ICmpInst::Predicate::ICMP_ULT:
return LHS.ult(RHS);
case ICmpInst::Predicate::ICMP_ULE:
return LHS.ule(RHS);
case ICmpInst::Predicate::ICMP_SGT:
return LHS.sgt(RHS);
case ICmpInst::Predicate::ICMP_SGE:
return LHS.sge(RHS);
case ICmpInst::Predicate::ICMP_SLT:
return LHS.slt(RHS);
case ICmpInst::Predicate::ICMP_SLE:
return LHS.sle(RHS);
default:
llvm_unreachable("Unexpected non-integer predicate.");
};
}
CmpInst::Predicate CmpInst::getFlippedSignednessPredicate(Predicate pred) {
assert(CmpInst::isRelational(pred) &&
"Call only with non-equality predicates!");

View File

@ -8651,31 +8651,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl {
static bool calculateICmpInst(const ICmpInst *ICI, const APInt &LHS,
const APInt &RHS) {
ICmpInst::Predicate Pred = ICI->getPredicate();
switch (Pred) {
case ICmpInst::ICMP_UGT:
return LHS.ugt(RHS);
case ICmpInst::ICMP_SGT:
return LHS.sgt(RHS);
case ICmpInst::ICMP_EQ:
return LHS.eq(RHS);
case ICmpInst::ICMP_UGE:
return LHS.uge(RHS);
case ICmpInst::ICMP_SGE:
return LHS.sge(RHS);
case ICmpInst::ICMP_ULT:
return LHS.ult(RHS);
case ICmpInst::ICMP_SLT:
return LHS.slt(RHS);
case ICmpInst::ICMP_NE:
return LHS.ne(RHS);
case ICmpInst::ICMP_ULE:
return LHS.ule(RHS);
case ICmpInst::ICMP_SLE:
return LHS.sle(RHS);
default:
llvm_unreachable("Invalid ICmp predicate!");
}
return ICmpInst::compare(LHS, RHS, ICI->getPredicate());
}
static APInt calculateCastInst(const CastInst *CI, const APInt &Src,

View File

@ -1557,41 +1557,15 @@ TEST(ConstantRange, MakeSatisfyingICmpRegion) {
ConstantRange(APInt(8, 4), APInt(8, -128)));
}
static bool icmp(CmpInst::Predicate Pred, const APInt &LHS, const APInt &RHS) {
switch (Pred) {
case CmpInst::Predicate::ICMP_EQ:
return LHS.eq(RHS);
case CmpInst::Predicate::ICMP_NE:
return LHS.ne(RHS);
case CmpInst::Predicate::ICMP_UGT:
return LHS.ugt(RHS);
case CmpInst::Predicate::ICMP_UGE:
return LHS.uge(RHS);
case CmpInst::Predicate::ICMP_ULT:
return LHS.ult(RHS);
case CmpInst::Predicate::ICMP_ULE:
return LHS.ule(RHS);
case CmpInst::Predicate::ICMP_SGT:
return LHS.sgt(RHS);
case CmpInst::Predicate::ICMP_SGE:
return LHS.sge(RHS);
case CmpInst::Predicate::ICMP_SLT:
return LHS.slt(RHS);
case CmpInst::Predicate::ICMP_SLE:
return LHS.sle(RHS);
default:
llvm_unreachable("Not an ICmp predicate!");
}
}
void ICmpTestImpl(CmpInst::Predicate Pred) {
unsigned Bits = 4;
EnumerateTwoConstantRanges(
Bits, [&](const ConstantRange &CR1, const ConstantRange &CR2) {
bool Exhaustive = true;
ForeachNumInConstantRange(CR1, [&](const APInt &N1) {
ForeachNumInConstantRange(
CR2, [&](const APInt &N2) { Exhaustive &= icmp(Pred, N1, N2); });
ForeachNumInConstantRange(CR2, [&](const APInt &N2) {
Exhaustive &= ICmpInst::compare(N1, N2, Pred);
});
});
EXPECT_EQ(CR1.icmp(Pred, CR2), Exhaustive);
});