[AArch64] Peephole rule to remove redundant cmp after cset.

Comparisons to zero or one after cset instructions can be safely
removed in examples like:

cset w9, eq          cset w9, eq
cmp  w9, #1   --->   <removed>
b.ne    .L1          b.ne    .L1

cset w9, eq          cset w9, eq
cmp  w9, #0   --->   <removed>
b.ne    .L1          b.eq    .L1

A peephole optimization has been added to detect suitable cases and remove
such redundant comparisons.

Differential Revision: https://reviews.llvm.org/D98564
This commit is contained in:
Pavel Iliin 2021-03-09 00:18:27 +00:00
parent d8805574c1
commit 2ec16103c6
4 changed files with 219 additions and 76 deletions

View File

@ -1463,14 +1463,16 @@ bool AArch64InstrInfo::optimizeCompareInstr(
// FIXME:CmpValue has already been converted to 0 or 1 in analyzeCompare
// function.
assert((CmpValue == 0 || CmpValue == 1) && "CmpValue must be 0 or 1!");
if (CmpValue != 0 || SrcReg2 != 0)
if (SrcReg2 != 0)
return false;
// CmpInstr is a Compare instruction if destination register is not used.
if (!MRI->use_nodbg_empty(CmpInstr.getOperand(0).getReg()))
return false;
return substituteCmpToZero(CmpInstr, SrcReg, MRI);
if (!CmpValue && substituteCmpToZero(CmpInstr, SrcReg, *MRI))
return true;
return removeCmpToZeroOrOne(CmpInstr, SrcReg, CmpValue, *MRI);
}
/// Get opcode of S version of Instr.
@ -1524,13 +1526,44 @@ static unsigned sForm(MachineInstr &Instr) {
}
/// Check if AArch64::NZCV should be alive in successors of MBB.
static bool areCFlagsAliveInSuccessors(MachineBasicBlock *MBB) {
static bool areCFlagsAliveInSuccessors(const MachineBasicBlock *MBB) {
for (auto *BB : MBB->successors())
if (BB->isLiveIn(AArch64::NZCV))
return true;
return false;
}
/// \returns The index of the condition-code operand of \p Instr when it is a
/// conditional branch or a select-like instruction, and -1 otherwise.
static int
findCondCodeUseOperandIdxForBranchOrSelect(const MachineInstr &Instr) {
  // Distance from the NZCV use operand back to the condition-code operand.
  int CCOffset;
  switch (Instr.getOpcode()) {
  case AArch64::Bcc:
    CCOffset = 2;
    break;
  case AArch64::CSINVWr:
  case AArch64::CSINVXr:
  case AArch64::CSINCWr:
  case AArch64::CSINCXr:
  case AArch64::CSELWr:
  case AArch64::CSELXr:
  case AArch64::CSNEGWr:
  case AArch64::CSNEGXr:
  case AArch64::FCSELSrrr:
  case AArch64::FCSELDrrr:
    CCOffset = 1;
    break;
  default:
    return -1;
  }
  int NZCVIdx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
  assert(NZCVIdx >= CCOffset);
  return NZCVIdx - CCOffset;
}
namespace {
struct UsedNZCV {
@ -1556,31 +1589,10 @@ struct UsedNZCV {
/// Returns AArch64CC::Invalid if either the instruction does not use condition
/// codes or we don't optimize CmpInstr in the presence of such instructions.
static AArch64CC::CondCode findCondCodeUsedByInstr(const MachineInstr &Instr) {
switch (Instr.getOpcode()) {
default:
return AArch64CC::Invalid;
case AArch64::Bcc: {
int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
assert(Idx >= 2);
return static_cast<AArch64CC::CondCode>(Instr.getOperand(Idx - 2).getImm());
}
case AArch64::CSINVWr:
case AArch64::CSINVXr:
case AArch64::CSINCWr:
case AArch64::CSINCXr:
case AArch64::CSELWr:
case AArch64::CSELXr:
case AArch64::CSNEGWr:
case AArch64::CSNEGXr:
case AArch64::FCSELSrrr:
case AArch64::FCSELDrrr: {
int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
assert(Idx >= 1);
return static_cast<AArch64CC::CondCode>(Instr.getOperand(Idx - 1).getImm());
}
}
int CCIdx = findCondCodeUseOperandIdxForBranchOrSelect(Instr);
return CCIdx >= 0 ? static_cast<AArch64CC::CondCode>(
Instr.getOperand(CCIdx).getImm())
: AArch64CC::Invalid;
}
static UsedNZCV getUsedNZCV(AArch64CC::CondCode CC) {
@ -1627,6 +1639,41 @@ static UsedNZCV getUsedNZCV(AArch64CC::CondCode CC) {
return UsedFlags;
}
/// \returns Conditions flags used after \p CmpInstr in its MachineBB if they
/// are not containing C or V flags and NZCV flags are not alive in successors
/// of the same \p CmpInstr and \p MI parent. \returns None otherwise.
///
/// Collect instructions using that flags in \p CCUseInstrs if provided.
static Optional<UsedNZCV>
examineCFlagsUse(MachineInstr &MI, MachineInstr &CmpInstr,
const TargetRegisterInfo &TRI,
SmallVectorImpl<MachineInstr *> *CCUseInstrs = nullptr) {
MachineBasicBlock *CmpParent = CmpInstr.getParent();
if (MI.getParent() != CmpParent)
return None;
if (areCFlagsAliveInSuccessors(CmpParent))
return None;
UsedNZCV NZCVUsedAfterCmp;
for (MachineInstr &Instr : instructionsWithoutDebug(
std::next(CmpInstr.getIterator()), CmpParent->instr_end())) {
if (Instr.readsRegister(AArch64::NZCV, &TRI)) {
AArch64CC::CondCode CC = findCondCodeUsedByInstr(Instr);
if (CC == AArch64CC::Invalid) // Unsupported conditional instruction
return None;
NZCVUsedAfterCmp |= getUsedNZCV(CC);
if (CCUseInstrs)
CCUseInstrs->push_back(&Instr);
}
if (Instr.modifiesRegister(AArch64::NZCV, &TRI))
break;
}
if (NZCVUsedAfterCmp.C || NZCVUsedAfterCmp.V)
return None;
return NZCVUsedAfterCmp;
}
/// \returns True if \p Opcode is the register-immediate form of ADDS
/// (32- or 64-bit).
static bool isADDSRegImm(unsigned Opcode) {
  return Opcode == AArch64::ADDSWri || Opcode == AArch64::ADDSXri;
}
@ -1646,44 +1693,21 @@ static bool isSUBSRegImm(unsigned Opcode) {
/// or if MI opcode is not the S form there must be neither defs of flags
/// nor uses of flags between MI and CmpInstr.
/// - and C/V flags are not used after CmpInstr
static bool canInstrSubstituteCmpInstr(MachineInstr *MI, MachineInstr *CmpInstr,
const TargetRegisterInfo *TRI) {
assert(MI);
assert(sForm(*MI) != AArch64::INSTRUCTION_LIST_END);
assert(CmpInstr);
static bool canInstrSubstituteCmpInstr(MachineInstr &MI, MachineInstr &CmpInstr,
const TargetRegisterInfo &TRI) {
assert(sForm(MI) != AArch64::INSTRUCTION_LIST_END);
const unsigned CmpOpcode = CmpInstr->getOpcode();
const unsigned CmpOpcode = CmpInstr.getOpcode();
if (!isADDSRegImm(CmpOpcode) && !isSUBSRegImm(CmpOpcode))
return false;
if (MI->getParent() != CmpInstr->getParent())
return false;
if (areCFlagsAliveInSuccessors(CmpInstr->getParent()))
if (!examineCFlagsUse(MI, CmpInstr, TRI))
return false;
AccessKind AccessToCheck = AK_Write;
if (sForm(*MI) != MI->getOpcode())
if (sForm(MI) != MI.getOpcode())
AccessToCheck = AK_All;
if (areCFlagsAccessedBetweenInstrs(MI, CmpInstr, TRI, AccessToCheck))
return false;
UsedNZCV NZCVUsedAfterCmp;
for (const MachineInstr &Instr :
instructionsWithoutDebug(std::next(CmpInstr->getIterator()),
CmpInstr->getParent()->instr_end())) {
if (Instr.readsRegister(AArch64::NZCV, TRI)) {
AArch64CC::CondCode CC = findCondCodeUsedByInstr(Instr);
if (CC == AArch64CC::Invalid) // Unsupported conditional instruction
return false;
NZCVUsedAfterCmp |= getUsedNZCV(CC);
}
if (Instr.modifiesRegister(AArch64::NZCV, TRI))
break;
}
return !NZCVUsedAfterCmp.C && !NZCVUsedAfterCmp.V;
return !areCFlagsAccessedBetweenInstrs(&MI, &CmpInstr, &TRI, AccessToCheck);
}
/// Substitute an instruction comparing to zero with another instruction
@ -1692,20 +1716,19 @@ static bool canInstrSubstituteCmpInstr(MachineInstr *MI, MachineInstr *CmpInstr,
/// Return true on success.
bool AArch64InstrInfo::substituteCmpToZero(
MachineInstr &CmpInstr, unsigned SrcReg,
const MachineRegisterInfo *MRI) const {
assert(MRI);
const MachineRegisterInfo &MRI) const {
// Get the unique definition of SrcReg.
MachineInstr *MI = MRI->getUniqueVRegDef(SrcReg);
MachineInstr *MI = MRI.getUniqueVRegDef(SrcReg);
if (!MI)
return false;
const TargetRegisterInfo *TRI = &getRegisterInfo();
const TargetRegisterInfo &TRI = getRegisterInfo();
unsigned NewOpc = sForm(*MI);
if (NewOpc == AArch64::INSTRUCTION_LIST_END)
return false;
if (!canInstrSubstituteCmpInstr(MI, &CmpInstr, TRI))
if (!canInstrSubstituteCmpInstr(*MI, CmpInstr, TRI))
return false;
// Update the instruction to set NZCV.
@ -1714,7 +1737,133 @@ bool AArch64InstrInfo::substituteCmpToZero(
bool succeeded = UpdateOperandRegClass(*MI);
(void)succeeded;
assert(succeeded && "Some operands reg class are incompatible!");
MI->addRegisterDefined(AArch64::NZCV, TRI);
MI->addRegisterDefined(AArch64::NZCV, &TRI);
return true;
}
/// \returns True if \p CmpInstr can be removed.
///
/// \p IsInvertCC is set to true if, after removing \p CmpInstr, the condition
/// codes used in \p CCUseInstrs must be inverted.
static bool canCmpInstrBeRemoved(MachineInstr &MI, MachineInstr &CmpInstr,
                                 int CmpValue, const TargetRegisterInfo &TRI,
                                 SmallVectorImpl<MachineInstr *> &CCUseInstrs,
                                 bool &IsInvertCC) {
  assert((CmpValue == 0 || CmpValue == 1) &&
         "Only comparisons to 0 or 1 considered for removal!");
  // MI is 'CSINCWr %vreg, wzr, wzr, <cc>' or 'CSINCXr %vreg, xzr, xzr, <cc>',
  // i.e. a cset producing 0 or 1.
  unsigned MIOpc = MI.getOpcode();
  if (MIOpc == AArch64::CSINCWr) {
    if (MI.getOperand(1).getReg() != AArch64::WZR ||
        MI.getOperand(2).getReg() != AArch64::WZR)
      return false;
  } else if (MIOpc == AArch64::CSINCXr) {
    if (MI.getOperand(1).getReg() != AArch64::XZR ||
        MI.getOperand(2).getReg() != AArch64::XZR)
      return false;
  } else {
    return false;
  }
  AArch64CC::CondCode MICC = findCondCodeUsedByInstr(MI);
  if (MICC == AArch64CC::Invalid)
    return false;
  // NZCV needs to be defined
  // (the search here matches a dead def of NZCV on MI).
  if (MI.findRegisterDefOperandIdx(AArch64::NZCV, true) != -1)
    return false;
  // CmpInstr is 'ADDS %vreg, 0' or 'SUBS %vreg, 0' or 'SUBS %vreg, 1'
  const unsigned CmpOpcode = CmpInstr.getOpcode();
  bool IsSubsRegImm = isSUBSRegImm(CmpOpcode);
  // Comparison against 1 is only expressible as SUBS.
  if (CmpValue && !IsSubsRegImm)
    return false;
  if (!CmpValue && !IsSubsRegImm && !isADDSRegImm(CmpOpcode))
    return false;
  // MI conditions allowed: eq, ne, mi, pl (those using only Z and/or N)
  UsedNZCV MIUsedNZCV = getUsedNZCV(MICC);
  if (MIUsedNZCV.C || MIUsedNZCV.V)
    return false;
  Optional<UsedNZCV> NZCVUsedAfterCmp =
      examineCFlagsUse(MI, CmpInstr, TRI, &CCUseInstrs);
  // Condition flags are not used in CmpInstr basic block successors and only
  // Z or N flags allowed to be used after CmpInstr within its basic block
  if (!NZCVUsedAfterCmp)
    return false;
  // Z or N flag used after CmpInstr must correspond to the flag used in MI
  if ((MIUsedNZCV.Z && NZCVUsedAfterCmp->N) ||
      (MIUsedNZCV.N && NZCVUsedAfterCmp->Z))
    return false;
  // If CmpInstr is comparison to zero MI conditions are limited to eq, ne
  if (MIUsedNZCV.N && !CmpValue)
    return false;
  // There must be no defs of flags between MI and CmpInstr
  if (areCFlagsAccessedBetweenInstrs(&MI, &CmpInstr, &TRI, AK_Write))
    return false;
  // Condition code is inverted in the following cases:
  // 1. MI condition is ne; CmpInstr is 'ADDS %vreg, 0' or 'SUBS %vreg, 0'
  // 2. MI condition is eq, pl; CmpInstr is 'SUBS %vreg, 1'
  IsInvertCC = (CmpValue && (MICC == AArch64CC::EQ || MICC == AArch64CC::PL)) ||
               (!CmpValue && MICC == AArch64CC::NE);
  return true;
}
/// Remove comparison in csinc-cmp sequence
///
/// Examples:
/// 1. \code
///   csinc w9, wzr, wzr, ne
///   cmp   w9, #0
///   b.eq
///    \endcode
/// to
///    \code
///   csinc w9, wzr, wzr, ne
///   b.ne
///    \endcode
///
/// 2. \code
///   csinc x2, xzr, xzr, mi
///   cmp   x2, #1
///   b.pl
///    \endcode
/// to
///    \code
///   csinc x2, xzr, xzr, mi
///   b.pl
///    \endcode
///
/// \param CmpInstr comparison instruction
/// \param SrcReg register holding the compared value; its unique virtual
///        register definition must be a csinc for the removal to apply
/// \param CmpValue constant the comparison is against (0 or 1)
/// \param MRI machine register info used to find SrcReg's definition
/// \return True when comparison removed
bool AArch64InstrInfo::removeCmpToZeroOrOne(
    MachineInstr &CmpInstr, unsigned SrcReg, int CmpValue,
    const MachineRegisterInfo &MRI) const {
  MachineInstr *MI = MRI.getUniqueVRegDef(SrcReg);
  if (!MI)
    return false;
  const TargetRegisterInfo &TRI = getRegisterInfo();
  SmallVector<MachineInstr *, 4> CCUseInstrs;
  bool IsInvertCC = false;
  if (!canCmpInstrBeRemoved(*MI, CmpInstr, CmpValue, TRI, CCUseInstrs,
                            IsInvertCC))
    return false;
  // Make transformation
  CmpInstr.eraseFromParent();
  if (IsInvertCC) {
    // Invert condition codes in CmpInstr CC users
    for (MachineInstr *CCUseInstr : CCUseInstrs) {
      int Idx = findCondCodeUseOperandIdxForBranchOrSelect(*CCUseInstr);
      assert(Idx >= 0 && "Unexpected instruction using CC.");
      MachineOperand &CCOperand = CCUseInstr->getOperand(Idx);
      AArch64CC::CondCode CCUse = AArch64CC::getInvertedCondCode(
          static_cast<AArch64CC::CondCode>(CCOperand.getImm()));
      CCOperand.setImm(CCUse);
    }
  }
  return true;
}

View File

@ -334,7 +334,9 @@ private:
MachineBasicBlock *TBB,
ArrayRef<MachineOperand> Cond) const;
bool substituteCmpToZero(MachineInstr &CmpInstr, unsigned SrcReg,
const MachineRegisterInfo *MRI) const;
const MachineRegisterInfo &MRI) const;
bool removeCmpToZeroOrOne(MachineInstr &CmpInstr, unsigned SrcReg,
int CmpValue, const MachineRegisterInfo &MRI) const;
/// Returns an unused general-purpose register which can be used for
/// constructing an outlined call if one exists. Returns 0 otherwise.

View File

@ -12,7 +12,6 @@ body: |
; CHECK: [[DEF:%[0-9]+]]:gpr64 = IMPLICIT_DEF
; CHECK: [[SUBSXrr:%[0-9]+]]:gpr64 = SUBSXrr killed [[DEF]], [[COPY]], implicit-def $nzcv
; CHECK: [[CSINCWr:%[0-9]+]]:gpr32common = CSINCWr $wzr, $wzr, 1, implicit $nzcv
; CHECK: [[SUBSWri:%[0-9]+]]:gpr32 = SUBSWri killed [[CSINCWr]], 1, 0, implicit-def $nzcv
; CHECK: Bcc 1, %bb.2, implicit $nzcv
; CHECK: B %bb.1
; CHECK: bb.1:
@ -51,8 +50,7 @@ body: |
; CHECK: [[DEF:%[0-9]+]]:gpr64 = IMPLICIT_DEF
; CHECK: [[SUBSXrr:%[0-9]+]]:gpr64 = SUBSXrr killed [[DEF]], [[COPY]], implicit-def $nzcv
; CHECK: [[CSINCXr:%[0-9]+]]:gpr64common = CSINCXr $xzr, $xzr, 1, implicit $nzcv
; CHECK: [[SUBSXri:%[0-9]+]]:gpr64 = SUBSXri killed [[CSINCXr]], 0, 0, implicit-def $nzcv
; CHECK: Bcc 0, %bb.2, implicit $nzcv
; CHECK: Bcc 1, %bb.2, implicit $nzcv
; CHECK: B %bb.1
; CHECK: bb.1:
; CHECK: successors: %bb.2(0x80000000)
@ -155,8 +153,7 @@ body: |
; CHECK: successors: %bb.1(0x40000000), %bb.2(0x40000000)
; CHECK: liveins: $nzcv
; CHECK: [[CSINCWr:%[0-9]+]]:gpr32common = CSINCWr $wzr, $wzr, 1, implicit $nzcv
; CHECK: [[ADDSWri:%[0-9]+]]:gpr32 = ADDSWri killed [[CSINCWr]], 0, 0, implicit-def $nzcv
; CHECK: Bcc 1, %bb.2, implicit $nzcv
; CHECK: Bcc 0, %bb.2, implicit $nzcv
; CHECK: B %bb.1
; CHECK: bb.1:
; CHECK: successors: %bb.2(0x80000000)
@ -254,8 +251,7 @@ body: |
; CHECK: successors: %bb.1(0x40000000), %bb.2(0x40000000)
; CHECK: liveins: $nzcv
; CHECK: [[CSINCWr:%[0-9]+]]:gpr32common = CSINCWr $wzr, $wzr, 5, implicit $nzcv
; CHECK: [[SUBSWri:%[0-9]+]]:gpr32 = SUBSWri killed [[CSINCWr]], 1, 0, implicit-def $nzcv
; CHECK: Bcc 4, %bb.2, implicit $nzcv
; CHECK: Bcc 5, %bb.2, implicit $nzcv
; CHECK: B %bb.1
; CHECK: bb.1:
; CHECK: successors: %bb.2(0x80000000)

View File

@ -189,8 +189,6 @@ define half @test_select(half %a, half %b, i1 zeroext %c) #0 {
; CHECK-CVT-DAG: fcvt s1, h1
; CHECK-CVT-DAG: fcvt s0, h0
; CHECK-CVT-DAG: fcmp s2, s3
; CHECK-CVT-DAG: cset [[CC:w[0-9]+]], ne
; CHECK-CVT-DAG: cmp [[CC]], #0
; CHECK-CVT-NEXT: fcsel s0, s0, s1, ne
; CHECK-CVT-NEXT: fcvt h0, s0
; CHECK-CVT-NEXT: ret
@ -228,8 +226,6 @@ define float @test_select_cc_f32_f16(float %a, float %b, half %c, half %d) #0 {
; CHECK-CVT-DAG: fcvt s0, h0
; CHECK-CVT-DAG: fcvt s1, h1
; CHECK-CVT-DAG: fcmp s2, s3
; CHECK-CVT-DAG: cset w8, ne
; CHECK-CVT-NEXT: cmp w8, #0
; CHECK-CVT-NEXT: fcsel s0, s0, s1, ne
; CHECK-CVT-NEXT: fcvt h0, s0
; CHECK-CVT-NEXT: ret