forked from OSchip/llvm-project
405 lines
15 KiB
C++
405 lines
15 KiB
C++
//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
/// \file Pass to pre-configure the shapes of AMX registers.
/// AMX registers need to be configured before use. The shapes of an AMX
/// register are encoded in the 1st and 2nd machine operands of the AMX pseudo
/// instructions.
///
/// The instruction ldtilecfg is used to configure the shapes. It must be
/// reachable for all variable shapes. ldtilecfg will be inserted more than once
/// if we cannot find a dominating point for all AMX instructions.
///
/// The configure register is caller saved according to the ABI. We need to
/// insert ldtilecfg again after a call instruction if the callee clobbers any
/// AMX registers.
///
/// This pass calculates all points where ldtilecfg needs to be inserted and
/// inserts it there. It reports an error if the reachability conditions
/// aren't met.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "X86.h"
|
|
#include "X86InstrBuilder.h"
|
|
#include "X86RegisterInfo.h"
|
|
#include "X86Subtarget.h"
|
|
#include "llvm/CodeGen/MachineFunctionPass.h"
|
|
#include "llvm/CodeGen/MachineInstr.h"
|
|
#include "llvm/CodeGen/MachineLoopInfo.h"
|
|
#include "llvm/CodeGen/MachineRegisterInfo.h"
|
|
#include "llvm/CodeGen/Passes.h"
|
|
#include "llvm/CodeGen/TargetInstrInfo.h"
|
|
#include "llvm/CodeGen/TargetRegisterInfo.h"
|
|
#include "llvm/InitializePasses.h"
|
|
|
|
using namespace llvm;

#define DEBUG_TYPE "tile-pre-config"

// Emits a fatal error that is prefixed with the name of the current function.
// The expansion refers to a variable named `MF`, which therefore has to be in
// scope wherever the macro is used. The expansion already ends with a
// semicolon, so use sites do not add one.
#define REPORT_CONFIG_FAIL                                                     \
  do {                                                                         \
    report_fatal_error(                                                        \
        MF.getName() +                                                         \
        ": Failed to config tile register, please define the shape earlier");  \
  } while (false);
|
|
|
|
namespace {
|
|
|
|
struct MIRef {
|
|
MachineInstr *MI = nullptr;
|
|
MachineBasicBlock *MBB = nullptr;
|
|
// A virtual position for instruction that will be inserted after MI.
|
|
size_t Pos = 0;
|
|
MIRef() = default;
|
|
MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
|
|
for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
|
|
++I, ++Pos)
|
|
MI = &*I;
|
|
}
|
|
MIRef(MachineInstr *MI)
|
|
: MI(MI), MBB(MI->getParent()),
|
|
Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
|
|
MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
|
|
: MI(MI), MBB(MBB),
|
|
Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
|
|
MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
|
|
: MI(MI), MBB(MBB), Pos(Pos) {}
|
|
operator bool() const { return MBB != nullptr; }
|
|
bool operator==(const MIRef &RHS) const {
|
|
return MI == RHS.MI && MBB == RHS.MBB;
|
|
}
|
|
bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
|
|
bool operator<(const MIRef &RHS) const {
|
|
// Comparison between different BBs happens when inserting a MIRef into set.
|
|
// So we compare MBB first to make the insertion happy.
|
|
return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
|
|
}
|
|
bool operator>(const MIRef &RHS) const {
|
|
// Comparison between different BBs happens when inserting a MIRef into set.
|
|
// So we compare MBB first to make the insertion happy.
|
|
return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
|
|
}
|
|
};
|
|
|
|
struct BBInfo {
|
|
MIRef FirstAMX;
|
|
MIRef LastCall;
|
|
bool HasAMXRegLiveIn = false;
|
|
bool TileCfgForbidden = false;
|
|
bool NeedTileCfgLiveIn = false;
|
|
};
|
|
|
|
class X86PreTileConfig : public MachineFunctionPass {
|
|
MachineRegisterInfo *MRI;
|
|
const MachineLoopInfo *MLI;
|
|
SmallSet<MachineInstr *, 8> DefVisited;
|
|
DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
|
|
DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
|
|
|
|
/// Check if the callee will clobber AMX registers.
|
|
bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
|
|
auto Iter = llvm::find_if(
|
|
MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
|
|
if (Iter == MI.operands_end())
|
|
return false;
|
|
UsableRegs.clearBitsInMask(Iter->getRegMask());
|
|
return !UsableRegs.none();
|
|
}
|
|
|
|
/// Check if MI is AMX pseudo instruction.
|
|
bool isAMXInstruction(MachineInstr &MI) {
|
|
if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
|
|
return false;
|
|
MachineOperand &MO = MI.getOperand(0);
|
|
// We can simply check if it is AMX instruction by its def.
|
|
// But we should exclude old API which uses physical registers.
|
|
if (MO.isReg() && MO.getReg().isVirtual() &&
|
|
MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) {
|
|
collectShapeInfo(MI);
|
|
return true;
|
|
}
|
|
// PTILESTOREDV is the only exception that doesn't def a AMX register.
|
|
return MI.getOpcode() == X86::PTILESTOREDV;
|
|
}
|
|
|
|
/// Check if it is an edge from loop bottom to loop head.
|
|
bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
|
|
if (!MLI->isLoopHeader(Header))
|
|
return false;
|
|
auto *ML = MLI->getLoopFor(Header);
|
|
if (ML->contains(Bottom) && ML->isLoopLatch(Bottom))
|
|
return true;
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Collect the shape def information for later use.
|
|
void collectShapeInfo(MachineInstr &MI);
|
|
|
|
/// Try to hoist shapes definded below AMX instructions.
|
|
bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
|
|
MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
|
|
auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
|
|
auto InsertPoint = FirstAMX.MI->getIterator();
|
|
for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
|
|
// Do not hoist instructions that access memory.
|
|
if (I->MI->mayLoadOrStore())
|
|
return false;
|
|
for (auto &MO : I->MI->operands()) {
|
|
if (MO.isDef())
|
|
continue;
|
|
// Do not hoist instructions if the sources' def under AMX instruction.
|
|
// TODO: We can handle isMoveImmediate MI here.
|
|
if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
|
|
return false;
|
|
// TODO: Maybe need more checks here.
|
|
}
|
|
MBB->insert(InsertPoint, I->MI->removeFromParent());
|
|
}
|
|
// We only need to mark the last shape in the BB now.
|
|
Shapes.clear();
|
|
Shapes.push_back(MIRef(&*--InsertPoint, MBB));
|
|
return true;
|
|
}
|
|
|
|
public:
|
|
X86PreTileConfig() : MachineFunctionPass(ID) {}
|
|
|
|
/// Return the pass name.
|
|
StringRef getPassName() const override {
|
|
return "Tile Register Pre-configure";
|
|
}
|
|
|
|
/// X86PreTileConfig analysis usage.
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override {
|
|
AU.setPreservesAll();
|
|
AU.addRequired<MachineLoopInfo>();
|
|
MachineFunctionPass::getAnalysisUsage(AU);
|
|
}
|
|
|
|
/// Clear MF related structures.
|
|
void releaseMemory() override {
|
|
ShapeBBs.clear();
|
|
DefVisited.clear();
|
|
BBVisitedInfo.clear();
|
|
}
|
|
|
|
/// Perform ldtilecfg instructions inserting.
|
|
bool runOnMachineFunction(MachineFunction &MF) override;
|
|
|
|
static char ID;
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
char X86PreTileConfig::ID = 0;
|
|
|
|
INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
|
|
"Tile Register Pre-configure", false, false)
|
|
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
|
|
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
|
|
"Tile Register Pre-configure", false, false)
|
|
|
|
/// Walk the definitions of the two shape operands (operands 1 and 2) of the
/// AMX instruction \p MI and remember every instruction that defines a shape.
///
/// Move-immediates are skipped, and every defining instruction is processed
/// once only. A PHI is followed through its incoming values; for an incoming
/// edge that is a loop back edge the PHI itself is recorded as the shape def
/// instead.
void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
  // Insert the def into the per-block list, keeping the list sorted and free
  // of duplicates.
  auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
    auto &BlockShapes = ShapeBBs[MBB];
    MIRef Candidate(MI, MBB);
    auto Slot = llvm::lower_bound(BlockShapes, Candidate);
    const bool AlreadyKnown = Slot != BlockShapes.end() && *Slot == Candidate;
    if (!AlreadyKnown)
      BlockShapes.insert(Slot, Candidate);
  };

  SmallVector<Register, 8> Pending;
  Pending.push_back(MI.getOperand(1).getReg());
  Pending.push_back(MI.getOperand(2).getReg());

  while (!Pending.empty()) {
    Register ShapeReg = Pending.pop_back_val();
    MachineInstr *DefMI = MRI->getVRegDef(ShapeReg);
    assert(DefMI && "R must has one define instruction");
    MachineBasicBlock *DefMBB = DefMI->getParent();

    // Immediates never need to be recorded (and are not marked as visited).
    if (DefMI->isMoveImmediate())
      continue;
    // Everything else is handled the first time it is reached only.
    if (!DefVisited.insert(DefMI).second)
      continue;

    if (!DefMI->isPHI()) {
      RecordShape(DefMI, DefMBB);
      continue;
    }

    // PHI operands come in (value, predecessor block) pairs after the def.
    for (unsigned Idx = 1; Idx < DefMI->getNumOperands(); Idx += 2) {
      MachineBasicBlock *Incoming = DefMI->getOperand(Idx + 1).getMBB();
      if (isLoopBackEdge(DefMBB, Incoming)) {
        // In this case, PHI is also a shape def.
        RecordShape(DefMI, DefMBB);
      } else {
        Pending.push_back(DefMI->getOperand(Idx).getReg());
      }
    }
  }
}
|
|
|
|
/// Insert ldtilecfg for all AMX instructions of \p MF.
///
/// The function proceeds in these steps:
///  1. Scan every block, recording the first AMX instruction, the last call
///     that clobbers AMX registers, and where the tile config has to be live.
///  2. Propagate the "tile config live in" requirement to predecessors until a
///     clobbering call or the entry block is reached.
///  3. Make sure no ldtilecfg ends up in front of a shape def, hoisting shape
///     defs inside a block where necessary, or reporting a fatal error.
///  4. Insert ldtilecfg at the computed points, sinking a point into
///     successors when its block may not hold the instruction.
///  5. Emit, at the start of the entry block, the stores that zero the config
///     stack slot and write the palette byte.
///
/// \returns false if no insertion point was found (nothing was changed), true
/// otherwise.
bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  const TargetInstrInfo *TII = ST.getInstrInfo();
  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
  const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);

  BitVector AMXRegs(TRI->getNumRegs());
  for (unsigned I = 0; I < RC->getNumRegs(); I++)
    AMXRegs.set(X86::TMM0 + I);

  // Iterate MF to collect information.
  MRI = &MF.getRegInfo();
  MLI = &getAnalysis<MachineLoopInfo>();
  SmallSet<MIRef, 8> CfgNeedInsert;
  SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
  for (auto &MBB : MF) {
    // 1-based index of the instruction inside its block.
    size_t Pos = 0;
    for (auto &MI : MBB) {
      ++Pos;
      if (isAMXInstruction(MI)) {
        // If there's call before the AMX, we need to reload tile config.
        if (BBVisitedInfo[&MBB].LastCall)
          CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
        else // Otherwise, we need tile config to live in this BB.
          BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
        // Always record the first AMX in case there's shape def after it.
        if (!BBVisitedInfo[&MBB].FirstAMX)
          BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
      } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
        // Record the call only if the callee clobbers any AMX register.
        BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
      }
    }
    if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
      if (&MBB == &MF.front())
        CfgNeedInsert.insert(MIRef(&MBB));
      else
        CfgLiveInBBs.push_back(&MBB);
    }
    if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
      for (auto *Succ : MBB.successors())
        if (!isLoopBackEdge(Succ, &MBB))
          BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
  }

  // Update NeedTileCfgLiveIn for predecessors.
  while (!CfgLiveInBBs.empty()) {
    MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (BBVisitedInfo[Pred].LastCall) {
        CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
      } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
        BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
        if (Pred == &MF.front())
          CfgNeedInsert.insert(MIRef(Pred));
        else
          CfgLiveInBBs.push_back(Pred);
      }
    }
  }

  // There's no AMX instruction if we didn't find a tile config live in point.
  if (CfgNeedInsert.empty())
    return false;

  // Avoid to insert ldtilecfg before any shape defs.
  SmallVector<MachineBasicBlock *, 8> ShapeWorkList;
  for (auto &BlockAndShapes : ShapeBBs) {
    MachineBasicBlock *ShapeMBB = BlockAndShapes.first;
    // TODO: We can hoist shapes across BBs here.
    if (BBVisitedInfo[ShapeMBB].HasAMXRegLiveIn)
      REPORT_CONFIG_FAIL
    if (BBVisitedInfo[ShapeMBB].FirstAMX &&
        BBVisitedInfo[ShapeMBB].FirstAMX < BlockAndShapes.second.back() &&
        !hoistShapesInBB(ShapeMBB, BlockAndShapes.second))
      REPORT_CONFIG_FAIL
    ShapeWorkList.push_back(ShapeMBB);
  }
  // Every block that can reach a shape def (ignoring loop back edges) must not
  // hold ldtilecfg.
  while (!ShapeWorkList.empty()) {
    MachineBasicBlock *MBB = ShapeWorkList.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
        BBVisitedInfo[Pred].TileCfgForbidden = true;
        ShapeWorkList.push_back(Pred);
      }
    }
  }

  DebugLoc DL;
  SmallSet<MIRef, 8> VisitedOrInserted;
  int SS = MF.getFrameInfo().CreateStackObject(
      ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);

  // Try to insert for the tile config live in points.
  for (const MIRef &Seed : CfgNeedInsert) {
    SmallSet<MIRef, 8> InsertPoints;
    SmallVector<MIRef, 8> SinkList({Seed});
    while (!SinkList.empty()) {
      MIRef Cur = SinkList.pop_back_val();
      if (!VisitedOrInserted.count(Cur)) {
        if (!BBVisitedInfo[Cur.MBB].TileCfgForbidden) {
          // If the BB is all shapes reachable, stop sink and try to insert.
          InsertPoints.insert(Cur);
        } else {
          // Avoid the BB to be multi visited.
          VisitedOrInserted.insert(Cur);
          // Sink the inserting point along the chain with NeedTileCfgLiveIn =
          // true when MBB isn't all shapes reachable.
          for (auto *Succ : Cur.MBB->successors())
            if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
              SinkList.push_back(MIRef(Succ));
        }
      }
    }

    // A given point might be forked due to shape conditions are not met.
    for (MIRef Point : InsertPoints) {
      // Make sure we insert ldtilecfg after the last shape def in MBB.
      if (ShapeBBs.count(Point.MBB) && Point < ShapeBBs[Point.MBB].back())
        Point = ShapeBBs[Point.MBB].back();
      // There're chances the MBB is sunk more than once. Record it to avoid
      // multi insert.
      if (VisitedOrInserted.insert(Point).second) {
        // A point without an instruction stands for the very start of the
        // block, so ldtilecfg goes in front of the first instruction. With an
        // instruction, ldtilecfg goes right behind it. Stepping past
        // instr_begin() in the first case would place ldtilecfg after the
        // first instruction and would step past end() in an empty block.
        auto II = Point.MBB->instr_begin();
        if (Point.MI)
          II = ++Point.MI->getIterator();
        addFrameReference(
            BuildMI(*Point.MBB, II, DL, TII->get(X86::LDTILECFG)), SS);
      }
    }
  }

  // Zero stack slot. All stores below are placed in front of the current
  // first instruction of the entry block; each branch clears 64 bytes with
  // the widest vector store available.
  MachineBasicBlock &MBB = MF.front();
  MachineInstr *MI = &*MBB.begin();
  if (ST.hasAVX512()) {
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm)
        .addReg(Zmm, RegState::Undef)
        .addReg(Zmm, RegState::Undef);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
        .addReg(Zmm);
  } else if (ST.hasAVX2()) {
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm)
        .addReg(Ymm, RegState::Undef)
        .addReg(Ymm, RegState::Undef);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
        .addReg(Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
        .addReg(Ymm);
  } else {
    assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm)
        .addReg(Xmm, RegState::Undef)
        .addReg(Xmm, RegState::Undef);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48)
        .addReg(Xmm);
  }
  // Fill in the palette first.
  addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);

  return true;
}
|
|
|
|
/// Factory for the tile pre-configuration pass. Returns a newly allocated
/// X86PreTileConfig instance.
FunctionPass *llvm::createX86PreTileConfigPass() {
  FunctionPass *Pass = new X86PreTileConfig();
  return Pass;
}
|