llvm-project/llvm/lib/Target/AArch64/AArch64PBQPRegAlloc.cpp

383 lines
11 KiB
C++

//===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file contains the AArch64 / Cortex-A57 specific register allocation
// constraints for use by the PBQP register allocator.
//
// It is essentially a transcription of what is contained in
// AArch64A57FPLoadBalancing, which tries to use a balanced
// mix of odd and even D-registers when performing a critical sequence of
// independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "aarch64-pbqp"
#include "AArch64PBQPRegAlloc.h"
#include "AArch64.h"
#include "AArch64RegisterInfo.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegAllocPBQP.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
namespace {
#ifndef NDEBUG
/// Debug-only helper (compiled out when NDEBUG is defined): tell whether
/// \p reg belongs to one of the FPR32, FPR64 or FPR128 register classes.
bool isFPReg(unsigned reg) {
  if (AArch64::FPR32RegClass.contains(reg))
    return true;
  if (AArch64::FPR64RegClass.contains(reg))
    return true;
  return AArch64::FPR128RegClass.contains(reg);
}
#endif
/// Tell whether \p reg is an odd-numbered S, D or Q register.
///
/// Only S0-S31, D0-D31 and Q0-Q31 are recognised; any other register value
/// ends up in llvm_unreachable.
bool isOdd(unsigned reg) {
  switch (reg) {
  // Even-numbered S registers.
  case AArch64::S0: case AArch64::S2: case AArch64::S4: case AArch64::S6:
  case AArch64::S8: case AArch64::S10: case AArch64::S12: case AArch64::S14:
  case AArch64::S16: case AArch64::S18: case AArch64::S20: case AArch64::S22:
  case AArch64::S24: case AArch64::S26: case AArch64::S28: case AArch64::S30:
  // Even-numbered D registers.
  case AArch64::D0: case AArch64::D2: case AArch64::D4: case AArch64::D6:
  case AArch64::D8: case AArch64::D10: case AArch64::D12: case AArch64::D14:
  case AArch64::D16: case AArch64::D18: case AArch64::D20: case AArch64::D22:
  case AArch64::D24: case AArch64::D26: case AArch64::D28: case AArch64::D30:
  // Even-numbered Q registers.
  case AArch64::Q0: case AArch64::Q2: case AArch64::Q4: case AArch64::Q6:
  case AArch64::Q8: case AArch64::Q10: case AArch64::Q12: case AArch64::Q14:
  case AArch64::Q16: case AArch64::Q18: case AArch64::Q20: case AArch64::Q22:
  case AArch64::Q24: case AArch64::Q26: case AArch64::Q28: case AArch64::Q30:
    return false;

  // Odd-numbered S registers.
  case AArch64::S1: case AArch64::S3: case AArch64::S5: case AArch64::S7:
  case AArch64::S9: case AArch64::S11: case AArch64::S13: case AArch64::S15:
  case AArch64::S17: case AArch64::S19: case AArch64::S21: case AArch64::S23:
  case AArch64::S25: case AArch64::S27: case AArch64::S29: case AArch64::S31:
  // Odd-numbered D registers.
  case AArch64::D1: case AArch64::D3: case AArch64::D5: case AArch64::D7:
  case AArch64::D9: case AArch64::D11: case AArch64::D13: case AArch64::D15:
  case AArch64::D17: case AArch64::D19: case AArch64::D21: case AArch64::D23:
  case AArch64::D25: case AArch64::D27: case AArch64::D29: case AArch64::D31:
  // Odd-numbered Q registers.
  case AArch64::Q1: case AArch64::Q3: case AArch64::Q5: case AArch64::Q7:
  case AArch64::Q9: case AArch64::Q11: case AArch64::Q13: case AArch64::Q15:
  case AArch64::Q17: case AArch64::Q19: case AArch64::Q21: case AArch64::Q23:
  case AArch64::Q25: case AArch64::Q27: case AArch64::Q29: case AArch64::Q31:
    return true;

  default:
    llvm_unreachable("Register is not from the expected class !");
  }
}
/// Tell whether \p reg1 and \p reg2 are both odd-numbered or both
/// even-numbered. In asserts-enabled builds both are checked to be FP
/// registers first.
bool haveSameParity(unsigned reg1, unsigned reg2) {
  assert(isFPReg(reg1) && "Expecting an FP register for reg1");
  assert(isFPReg(reg2) && "Expecting an FP register for reg2");

  const bool firstIsOdd = isOdd(reg1);
  const bool secondIsOdd = isOdd(reg2);
  return firstIsOdd == secondIsOdd;
}
}
/// Bias the PBQP edge between \p Rd and \p Ra towards assignments in which
/// both registers have the same parity.
///
/// When no edge joins the two nodes yet, one is created: a pair of overlapping
/// physical registers gets an infinite cost if the two live intervals overlap,
/// every other same-parity pair costs 0 and every other mixed-parity pair
/// costs 1. When an edge already exists, then for each row every mixed-parity
/// entry that is cheaper than the row's largest finite same-parity entry is
/// raised to that maximum plus one.
///
/// \returns false (graph untouched) if \p Rd and \p Ra are the same register
/// or if either of them is a physical register; true once the edge has been
/// created or updated.
bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
                                                    unsigned Ra) {
  if (Rd == Ra)
    return false;

  LiveIntervals &LIs = G.getMetadata().LIS;

  // Physical registers are not constrained here.
  if (TRI->isPhysicalRegister(Rd) || TRI->isPhysicalRegister(Ra)) {
    LLVM_DEBUG(dbgs() << "Rd is a physical reg:" << TRI->isPhysicalRegister(Rd)
                      << '\n');
    LLVM_DEBUG(dbgs() << "Ra is a physical reg:" << TRI->isPhysicalRegister(Ra)
                      << '\n');
    return false;
  }

  const PBQP::PBQPNum Unallocatable =
      std::numeric_limits<PBQP::PBQPNum>::infinity();

  PBQPRAGraph::NodeId RdNode = G.getMetadata().getNodeIdForVReg(Rd);
  PBQPRAGraph::NodeId RaNode = G.getMetadata().getNodeIdForVReg(Ra);

  const PBQPRAGraph::NodeMetadata::AllowedRegVector *RowRegs =
      &G.getNodeMetadata(RdNode).getAllowedRegs();
  const PBQPRAGraph::NodeMetadata::AllowedRegVector *ColRegs =
      &G.getNodeMetadata(RaNode).getAllowedRegs();

  PBQPRAGraph::EdgeId Edge = G.findEdge(RdNode, RaNode);

  // No edge yet: build a fresh cost matrix. Row 0 and column 0 are left at
  // their initial value of 0; allowed register k lives at index k + 1.
  if (Edge == G.invalidEdgeId()) {
    const LiveInterval &RdInterval = LIs.getInterval(Rd);
    const LiveInterval &RaInterval = LIs.getInterval(Ra);
    const bool Interfere = RdInterval.overlaps(RaInterval);

    const unsigned NumRows = RowRegs->size();
    const unsigned NumCols = ColRegs->size();
    PBQPRAGraph::RawMatrix NewCosts(NumRows + 1, NumCols + 1, 0);

    for (unsigned Row = 0; Row != NumRows; ++Row) {
      const unsigned PhysRd = (*RowRegs)[Row];
      for (unsigned Col = 0; Col != NumCols; ++Col) {
        const unsigned PhysRa = (*ColRegs)[Col];
        if (Interfere && TRI->regsOverlap(PhysRd, PhysRa)) {
          NewCosts[Row + 1][Col + 1] = Unallocatable;
          continue;
        }
        NewCosts[Row + 1][Col + 1] = haveSameParity(PhysRd, PhysRa) ? 0.0 : 1.0;
      }
    }

    G.addEdge(RdNode, RaNode, std::move(NewCosts));
    return true;
  }

  // The edge may be stored with its endpoints in the opposite order; line the
  // row/column register lists up with the edge's own orientation.
  if (G.getEdgeNode1Id(Edge) == RaNode) {
    std::swap(RdNode, RaNode);
    std::swap(RowRegs, ColRegs);
  }

  // Refine a copy of the existing costs so that, within each row, mixed-parity
  // choices end up more expensive than the costliest finite same-parity one.
  PBQPRAGraph::RawMatrix Refined(G.getEdgeCosts(Edge));
  const unsigned NumRows = RowRegs->size();
  const unsigned NumCols = ColRegs->size();

  for (unsigned Row = 0; Row != NumRows; ++Row) {
    const unsigned RowReg = (*RowRegs)[Row];

    // Largest finite cost among the same-parity entries of this row.
    // NOTE(review): if PBQPNum is a floating-point type, min() is the smallest
    // positive value rather than the most negative one - confirm this seed is
    // the intended one.
    PBQP::PBQPNum MaxSameParity = std::numeric_limits<PBQP::PBQPNum>::min();
    for (unsigned Col = 0; Col != NumCols; ++Col) {
      if (!haveSameParity(RowReg, (*ColRegs)[Col]))
        continue;
      const PBQP::PBQPNum Cost = Refined[Row + 1][Col + 1];
      if (Cost != Unallocatable && Cost > MaxSameParity)
        MaxSameParity = Cost;
    }

    // Lift every cheaper mixed-parity entry above that maximum.
    for (unsigned Col = 0; Col != NumCols; ++Col) {
      if (haveSameParity(RowReg, (*ColRegs)[Col]))
        continue;
      if (MaxSameParity > Refined[Row + 1][Col + 1])
        Refined[Row + 1][Col + 1] = MaxSameParity + 1.0;
    }
  }

  G.updateEdgeCosts(Edge, std::move(Refined));
  return true;
}
/// Update the set of tracked accumulator chains for \p Rd / \p Ra and make it
/// more expensive, on the already existing PBQP edges, for \p Rd to share its
/// parity with any other tracked chain that is live at the same time.
///
/// Chain bookkeeping: if \p Ra is a tracked chain, the chain is moved from
/// \p Ra to \p Rd (nothing changes when they are equal); otherwise a new chain
/// is started at \p Rd.
///
/// For every other tracked chain register whose live interval overlaps the one
/// of \p Rd, the edge between the two nodes must already exist (asserted). For
/// each row of that edge's cost matrix, every same-parity entry that is cheaper
/// than the row's largest finite mixed-parity entry is raised to that maximum
/// plus one.
void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
                                                    unsigned Ra) {
  LiveIntervals &LIs = G.getMetadata().LIS;

  // Do some Chain management
  if (Chains.count(Ra)) {
    if (Rd != Ra) {
      LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)
                        << " to " << printReg(Rd, TRI) << '\n';);
      Chains.remove(Ra);
      Chains.insert(Rd);
    }
  } else {
    LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)
                      << '\n';);
    Chains.insert(Rd);
  }

  // Rd's node must stay fixed for the whole loop below. It is const on
  // purpose: swapping it with the other chain's node while normalising the
  // edge orientation would make later iterations query the wrong node.
  const PBQPRAGraph::NodeId RdNode = G.getMetadata().getNodeIdForVReg(Rd);

  const LiveInterval &ld = LIs.getInterval(Rd);
  for (auto r : Chains) {
    // Skip self
    if (r == Rd)
      continue;

    // Only chains that are live at the same time as Rd compete with it.
    const LiveInterval &lr = LIs.getInterval(r);
    if (!ld.overlaps(lr))
      continue;

    const PBQPRAGraph::NodeId ChainNode = G.getMetadata().getNodeIdForVReg(r);

    // Register lists indexing the rows / columns of the edge cost matrix.
    const PBQPRAGraph::NodeMetadata::AllowedRegVector *rowAllowed =
        &G.getNodeMetadata(RdNode).getAllowedRegs();
    const PBQPRAGraph::NodeMetadata::AllowedRegVector *colAllowed =
        &G.getNodeMetadata(ChainNode).getAllowedRegs();

    PBQPRAGraph::EdgeId edge = G.findEdge(RdNode, ChainNode);
    assert(edge != G.invalidEdgeId() &&
           "PBQP error ! The edge should exist !");

    LLVM_DEBUG(dbgs() << "Refining constraint !\n";);

    // The edge may be stored with the other chain's node first; in that case
    // the matrix rows belong to that node. Only the register lists need to be
    // exchanged - the node ids are not used past this point.
    if (G.getEdgeNode1Id(edge) == ChainNode)
      std::swap(rowAllowed, colAllowed);

    // Enforce that cost is higher with all other Chains of the same parity
    PBQP::Matrix costs(G.getEdgeCosts(edge));
    for (unsigned i = 0, ie = rowAllowed->size(); i != ie; ++i) {
      unsigned pRow = (*rowAllowed)[i];

      // Get the maximum cost (excluding unallocatable reg) over the registers
      // of the other parity.
      PBQP::PBQPNum otherParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
      for (unsigned j = 0, je = colAllowed->size(); j != je; ++j) {
        unsigned pCol = (*colAllowed)[j];
        if (!haveSameParity(pRow, pCol))
          if (costs[i + 1][j + 1] !=
                  std::numeric_limits<PBQP::PBQPNum>::infinity() &&
              costs[i + 1][j + 1] > otherParityMax)
            otherParityMax = costs[i + 1][j + 1];
      }

      // Ensure all registers with same parity have a higher cost
      // than otherParityMax
      for (unsigned j = 0, je = colAllowed->size(); j != je; ++j) {
        unsigned pCol = (*colAllowed)[j];
        if (haveSameParity(pRow, pCol))
          if (otherParityMax > costs[i + 1][j + 1])
            costs[i + 1][j + 1] = otherParityMax + 1.0;
      }
    }
    G.updateEdgeCosts(edge, std::move(costs));
  }
}
static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
const MachineInstr &MI) {
const LiveInterval &LI = LIs.getInterval(reg);
SlotIndex SI = LIs.getInstructionIndex(MI);
return LI.expiredAt(SI);
}
/// Walk every instruction of the function behind \p G and add the A57
/// accumulator-chain constraints to the PBQP graph.
///
/// Chains are tracked per basic block. Before each instruction, chains whose
/// register has expired are dropped. Scalar FMADD/FMSUB/FNMADD/FNMSUB
/// instructions (operand 0 = Rd, operand 3 = Ra) get an intra-chain constraint
/// and, if that one was applied, an inter-chain constraint; FMLAv2f32 and
/// FMLSv2f32 only get an inter-chain constraint on their operand 0.
void A57ChainingConstraint::apply(PBQPRAGraph &G) {
  const MachineFunction &MF = G.getMetadata().MF;
  LiveIntervals &LIs = G.getMetadata().LIS;

  TRI = MF.getSubtarget().getRegisterInfo();
  LLVM_DEBUG(MF.dump());

  for (const auto &MBB : MF) {
    Chains.clear(); // FIXME: really needed ? Could not work at MF level ?

    for (const auto &MI : MBB) {
      // Forget Chains which have expired. The expired registers are collected
      // first and removed only once the iteration is over, so that Chains is
      // never modified while it is being traversed.
      SmallVector<unsigned, 8> toDel;
      for (auto r : Chains) {
        if (regJustKilledBefore(LIs, r, MI)) {
          LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";
                     MI.print(dbgs()););
          toDel.push_back(r);
        }
      }

      while (!toDel.empty()) {
        Chains.remove(toDel.back());
        toDel.pop_back();
      }

      switch (MI.getOpcode()) {
      case AArch64::FMSUBSrrr:
      case AArch64::FMADDSrrr:
      case AArch64::FNMSUBSrrr:
      case AArch64::FNMADDSrrr:
      case AArch64::FMSUBDrrr:
      case AArch64::FMADDDrrr:
      case AArch64::FNMSUBDrrr:
      case AArch64::FNMADDDrrr: {
        unsigned Rd = MI.getOperand(0).getReg();
        unsigned Ra = MI.getOperand(3).getReg();

        // The inter-chain constraint is only added when the intra-chain one
        // reported that it was applied.
        if (addIntraChainConstraint(G, Rd, Ra))
          addInterChainConstraint(G, Rd, Ra);
        break;
      }

      case AArch64::FMLAv2f32:
      case AArch64::FMLSv2f32: {
        // Operand 0 is passed as both Rd and Ra.
        unsigned Rd = MI.getOperand(0).getReg();
        addInterChainConstraint(G, Rd, Rd);
        break;
      }

      default:
        break;
      }
    }
  }
}