forked from OSchip/llvm-project
423 lines
17 KiB
C++
423 lines
17 KiB
C++
//===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
/// Insert tilecfg for each area of key AMX intrinsic.
|
|
/// All the key AMX intrinsic's tile operand must come from tileload. And the
|
|
/// def tile of key AMX intrinsic must be tilestored.
|
|
/// take tdpbssd for example:
|
|
/// --------------------------------------------------------------------------
|
|
/// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...) key
|
|
/// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...) |
|
|
/// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...) amx
|
|
/// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3) |
|
|
/// call void @llvm.x86.tilestored64.internal(... td) area
|
|
/// --------------------------------------------------------------------------
|
|
/// This pass will insert tilecfg before every key-amx-area, something like:
|
|
/// --------------------------------------------------------------------------
|
|
/// %cfgmem = alloca <16 x i32>, align 4 * allocate mem
|
|
/// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem * zero init
|
|
/// ...
|
|
/// ... pre-config shape of %t1 *
|
|
/// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1 *
|
|
/// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2 * pre-config
|
|
/// ... *
|
|
/// ... pre-config shape of %t2 * shapes
|
|
/// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 *
|
|
/// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 *
|
|
/// ...
|
|
/// call void @llvm.x86.ldtilecfg(i8* %cfgmem) * tile config
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
#include "X86.h"
|
|
#include "llvm/ADT/SmallSet.h"
|
|
#include "llvm/Analysis/TargetTransformInfo.h"
|
|
#include "llvm/CodeGen/Passes.h"
|
|
#include "llvm/CodeGen/TargetPassConfig.h"
|
|
#include "llvm/CodeGen/ValueTypes.h"
|
|
#include "llvm/IR/DataLayout.h"
|
|
#include "llvm/IR/Function.h"
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/Instructions.h"
|
|
#include "llvm/IR/IntrinsicInst.h"
|
|
#include "llvm/IR/IntrinsicsX86.h"
|
|
#include "llvm/IR/PatternMatch.h"
|
|
#include "llvm/InitializePasses.h"
|
|
#include "llvm/Pass.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include "llvm/Target/TargetMachine.h"
|
|
|
|
using namespace llvm;
|
|
using namespace PatternMatch;
|
|
|
|
#define DEBUG_TYPE "pre-amx-config"
|
|
|
|
/// Return true if \p II either produces an x86_amx value or takes at least
/// one x86_amx-typed operand.
static bool isAMXIntrinsic(IntrinsicInst *II) {
  // A tile-typed result is sufficient on its own.
  if (II->getType()->isX86_AMXTy())
    return true;

  // Otherwise any tile-typed operand qualifies the call.
  for (unsigned Idx = 0, NumOps = II->getNumOperands(); Idx != NumOps; ++Idx)
    if (II->getOperand(Idx)->getType()->isX86_AMXTy())
      return true;

  return false;
}
|
|
|
|
/// Return true if \p II is a call to llvm.x86.tileloadd64.internal.
static bool isTileLoad(IntrinsicInst *II) {
  const auto IID = II->getIntrinsicID();
  return IID == Intrinsic::x86_tileloadd64_internal;
}
|
|
|
|
/// Return true if \p II is a call to llvm.x86.tilestored64.internal.
static bool isTileStore(IntrinsicInst *II) {
  const auto IID = II->getIntrinsicID();
  return IID == Intrinsic::x86_tilestored64_internal;
}
|
|
|
|
#ifndef NDEBUG
|
|
/// Return true if \p II produces an x86_amx value and none of its operands
/// is x86_amx-typed, i.e. it defines a tile without consuming one.
static bool onlyTileDef(IntrinsicInst *II) {
  // Without a tile-typed result this cannot be a tile definition.
  if (!II->getType()->isX86_AMXTy())
    return false;

  // Any tile-typed operand means the intrinsic also uses a tile.
  for (unsigned Idx = 0, NumOps = II->getNumOperands(); Idx != NumOps; ++Idx)
    if (II->getOperand(Idx)->getType()->isX86_AMXTy())
      return false;

  return true;
}
|
|
|
|
/// Return true if \p I is a terminator or a call that is not an intrinsic.
/// The scan for a key AMX area asserts that no such instruction is met
/// before the area's tilestore.
static bool brokenVolatile(Instruction *I) {
  if (I->isTerminator())
    return true;

  // Todo: this is a weak way to identify a normal call.
  return isa<CallInst>(I) && !isa<IntrinsicInst>(I);
}
|
|
#endif
|
|
|
|
namespace {
|
|
class X86PreAMXConfig {
|
|
Function &F;
|
|
|
|
public:
|
|
X86PreAMXConfig(Function &Func) : F(Func) {}
|
|
bool preTileConfig();
|
|
bool addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
|
|
bool findConfigShapes(
|
|
DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes);
|
|
bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
|
|
bool preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
|
|
SmallVector<Value *, 8> &Shapes);
|
|
BasicBlock::iterator
|
|
getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
|
|
SmallVector<Value *, 8> &Shapes);
|
|
bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
|
|
IntrinsicInst *KeyAMX);
|
|
};
|
|
|
|
// Write the shapes into tilecfg's mem in order. This may not be right.
|
|
// Because the first shape may not correspond to the first tmm register,
|
|
// so we need to handle it at X86FastTileConfig::materializeTileCfg()
|
|
// after register allocation.
|
|
// For example:
|
|
// --------------------------------------------------------------------------
|
|
// zeroinitialize tilecfg's mem (of ldtilecfg)
|
|
// --------------------------------------------------------------------------
|
|
// ... pre-config shape of %t1 *
|
|
// %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48 *
|
|
// %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16 *
|
|
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1 *
|
|
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2 * pre-config
|
|
// ... *
|
|
// ... pre-config shape of %t2 *
|
|
// %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49 *
|
|
// %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18 *
|
|
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 * shapes
|
|
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 *
|
|
// ... *
|
|
// ... pre-config shape of %t3 * of
|
|
// %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50 *
|
|
// %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20 *
|
|
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1 *
|
|
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2 *
|
|
// ... * tiles
|
|
// ... pre-config shape of %td *
|
|
// %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51 *
|
|
// %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22 *
|
|
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1 *
|
|
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2 *
|
|
// --------------------------------------------------------------------------
|
|
// call void @llvm.x86.ldtilecfg(i8* %mem) * tile config
|
|
// --------------------------------------------------------------------------
|
|
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
|
|
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
|
|
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
|
|
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
|
|
// call void @llvm.x86.tilestored64.internal(... td) area
|
|
// --------------------------------------------------------------------------
|
|
/// Emit, before \p Pos, the stores that fill the tile-config memory addressed
/// by \p I8Ptr.
///
/// Byte 0 receives the value 1 (the palette). \p Shapes is read as a flat
/// list of (row, col) pairs; for the pair with index N the row is truncated
/// to i8 and stored at byte offset 48 + N, and the col is stored through an
/// i16 pointer at byte offset 16 + 2 * N.
///
/// \returns true if at least one shape pair was written.
bool X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
                                      SmallVector<Value *, 8> &Shapes) {
  LLVMContext &Ctx = Pos->getParent()->getContext();
  Type *I8Ty = Type::getInt8Ty(Ctx);
  Type *I16Ty = Type::getInt16Ty(Ctx);
  Type *I64Ty = Type::getInt64Ty(Ctx);

  // TODO: Currently we defaultly set Palette = 1, it may be assigned to
  // other value in the future.
  Value *PaletteAddr = GetElementPtrInst::Create(
      I8Ty, I8Ptr, ConstantInt::get(I64Ty, 0), "", Pos);
  new StoreInst(ConstantInt::get(I8Ty, 1), PaletteAddr, Pos);

  // Shapes holds row/col pairs, so half its size is the number of tiles.
  const int NumTiles = Shapes.size() / 2;
  for (int Tile = 0; Tile != NumTiles; ++Tile) {
    const std::string Prefix = "amx.tmm." + itostr(Tile);

    // Address of the one-byte row field of this tile.
    Value *RowAddr =
        GetElementPtrInst::Create(I8Ty, I8Ptr, ConstantInt::get(I64Ty, 48 + Tile),
                                  Prefix + ".shape.row", Pos);

    // Address of the two-byte col field: computed in bytes, then viewed as
    // an i16 pointer.
    Value *ColByteAddr = GetElementPtrInst::Create(
        I8Ty, I8Ptr, ConstantInt::get(I64Ty, 16 + Tile * 2), "", Pos);
    Value *ColAddr = new BitCastInst(ColByteAddr, PointerType::get(I16Ty, 0),
                                     Prefix + ".shape.col", Pos);

    // The row is narrowed to i8 before the store; the col is stored as-is.
    Value *RowI8 = new TruncInst(Shapes[Tile * 2], I8Ty, "", Pos);
    new StoreInst(RowI8, RowAddr, Pos);
    new StoreInst(Shapes[Tile * 2 + 1], ColAddr, Pos);
  }

  return NumTiles > 0;
}
|
|
|
|
/// Insert a complete tile configuration in front of \p ModelStart.
///
/// A <16 x i32> alloca is created at the top of the entry block, zeroed,
/// filled with \p Shapes via preWriteTileCfg() and finally passed to the
/// ldtilecfg intrinsic, which is emitted right before \p ModelStart.
///
/// \returns true (the zero-initializing store is always created).
bool X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
                                    SmallVector<Value *, 8> &Shapes) {
  IRBuilder<> Builder(ModelStart);
  LLVMContext &Ctx = Builder.getContext();
  const DataLayout &DL = F.getParent()->getDataLayout();

  Type *CfgTy = VectorType::get(Builder.getInt32Ty(), 16, false);
  const Align CfgAlign = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));

  // The config memory lives at the very top of the entry block.
  Instruction *EntryFront = &F.getEntryBlock().front();
  AllocaInst *CfgMem =
      new AllocaInst(CfgTy, DL.getAllocaAddrSpace(), "", EntryFront);
  CfgMem->setAlignment(CfgAlign);

  // ldtilecfg takes the config memory as an i8 pointer.
  Value *I8Ptr = Builder.CreateBitCast(CfgMem, Builder.getInt8PtrTy());
  std::array<Value *, 1> Args = {I8Ptr};
  Instruction *Cfg =
      Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, None, Args);

  // Zero the whole config memory ahead of the ldtilecfg call.
  Instruction *Init0 = new StoreInst(Constant::getNullValue(CfgTy), CfgMem,
                                     false, CfgAlign, Cfg);
  assert(Init0 && "Not Zero initilizate the cfg mem!");

  // The shape stores go between the zero-init and the ldtilecfg call.
  preWriteTileCfg(I8Ptr, Cfg, Shapes);

  return Init0 != nullptr;
}
|
|
|
|
// Todo: We may need to handle "more than one store" case in the future.
|
|
/// Verify that a scanned area follows the volatile model.
///
/// Without a key AMX intrinsic the area must consist of exactly one tileload
/// whose result is the tile written by \p Store. With a key AMX intrinsic,
/// every tile operand of \p KeyAMX must be one of \p Loads, no load may be
/// left over, and \p Store must write the tile defined by \p KeyAMX.
///
/// Note that entries matched against \p KeyAMX operands are erased from
/// \p Loads.
bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
                                         IntrinsicInst *Store,
                                         IntrinsicInst *KeyAMX) {
  // Operand 4 of the tilestore is the tile being stored.
  Value *StoredTile = Store->getOperand(4);

  if (KeyAMX) {
    // All Loads should be operands of KeyAMX, and all tile operands of
    // KeyAMX should come from Loads.
    for (Value *Operand : KeyAMX->operands()) {
      if (!Operand->getType()->isX86_AMXTy())
        continue;
      if (!Loads.erase(Operand))
        return false;
    }

    // A load that was never consumed by KeyAMX breaks the model.
    if (!Loads.empty())
      return false;

    // The def of KeyAMX should be stored into mem.
    // Todo: can a key amx intrinsic have no def?
    return StoredTile == KeyAMX;
  }

  // Only has tileload and tilestore.
  return Loads.size() == 1 && Loads.contains(StoredTile);
}
|
|
|
|
/// Append the shapes needed by \p KeyAMX to \p Shapes.
///
/// For every tile-typed operand of \p KeyAMX, operands 0 and 1 of the
/// intrinsic that defines it (asserted to be a tileload) are appended.
/// Unless \p KeyAMX is itself a tilestore, operands 0 and 1 of \p KeyAMX are
/// appended last as the shape of the tile it defines.
///
/// \returns true if \p Shapes is non-empty afterwards.
bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
                                      SmallVector<Value *, 8> &Shapes) {
  for (Value *Operand : KeyAMX->operands()) {
    if (!Operand->getType()->isX86_AMXTy())
      continue;

    auto *TileDef = dyn_cast<IntrinsicInst>(Operand);
    assert((TileDef && isTileLoad(TileDef)) &&
           "All KeyAMX's tile definiation should comes from TileLoad!");
    Shapes.push_back(TileDef->getOperand(0));
    Shapes.push_back(TileDef->getOperand(1));
  }

  // A tilestore defines no tile, so it contributes no shape of its own.
  if (!isTileStore(KeyAMX)) {
    Shapes.push_back(KeyAMX->getOperand(0));
    Shapes.push_back(KeyAMX->getOperand(1));
  }

  return !Shapes.empty();
}
|
|
|
|
// Collect the shapes and skip the area of current key amx intrinsic.
|
|
//
|
|
// For example:
|
|
// ...
|
|
// --------------------------------------------------------------------------
|
|
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) record (m,k)
|
|
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) record (k,n)
|
|
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) record (m,n)
|
|
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
|
|
// call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,n)
|
|
// --------------------------------------------------------------------------
|
|
/// Scan one key AMX area starting at \p Iter and collect its shapes.
///
/// Instructions of Iter's basic block are visited until the first tilestore,
/// which marks the end of the area. Tileloads met on the way are recorded and
/// at most one other AMX intrinsic is accepted as the key AMX intrinsic. The
/// area is then validated with checkVolatileModel() and its shapes are
/// written to \p Shapes (which must be empty on entry) in the order produced
/// by getKeyAMXShapes().
///
/// A missing tilestore or a model violation is reported with
/// report_fatal_error, so the returned iterator is always dereferenceable.
///
/// \returns the iterator pointing at the tilestore that ends the area.
BasicBlock::iterator
X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                                          SmallVector<Value *, 8> &Shapes) {
  IntrinsicInst *KeyAMX = nullptr;
  BasicBlock *BB = Iter->getParent();
  BasicBlock::iterator PosEnd = BB->end();
  SmallSet<Value *, 4> Loads;

  // See TileStore as "Config Position End" and check volatile model.
  for (auto I = Iter, E = BB->end(); I != E; ++I) {
    assert(!brokenVolatile(&*I) && "Not reach tile store!");
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
    if (!II || !isAMXIntrinsic(II))
      continue;

    if (isTileLoad(II)) {
      Loads.insert(II);
    } else if (isTileStore(II)) {
      if (!checkVolatileModel(Loads, II, KeyAMX))
        report_fatal_error("Not Volatile AMX Model!");
      PosEnd = I;
      break;
    } else {
      assert(!KeyAMX && "Too many key amx intrinsic!");
      KeyAMX = II;
    }
  }

  // Without a tilestore PosEnd is still BB->end(). Dereferencing it below, or
  // letting the caller advance past it, would be undefined behaviour once
  // asserts are compiled out, so fail loudly in every build mode.
  if (PosEnd == BB->end())
    report_fatal_error("Not find TileStore!");

  // See KeyAMX as TileStore if only TileLoad and TileStore.
  // PosEnd was set from an IntrinsicInst above, so the cast cannot fail.
  if (!KeyAMX)
    KeyAMX = cast<IntrinsicInst>(&*PosEnd);

  // Get Shapes in order.
  assert(Shapes.empty() && "Shapes should be clean.");
  getKeyAMXShapes(KeyAMX, Shapes);

  return PosEnd;
}
|
|
|
|
// Record a key amx area's shapes with its position.
|
|
// Use the first tileload as its position.
|
|
// For example:
|
|
// ...
|
|
// --------------------------------------------------------------------------
|
|
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) <-- pos
|
|
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) /
|
|
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) shapes:
|
|
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3) (m,k)(k,n)
|
|
// call void @llvm.x86.tilestored64.internal(m, n,... td) (m,n)(m,n)
|
|
// --------------------------------------------------------------------------
|
|
/// Locate every key AMX area of the function.
///
/// Each basic block is walked in order. The first AMX intrinsic of an area
/// becomes the key of a new entry in \p PosAndShapes, whose value receives
/// the area's shapes; the walk then resumes after the tilestore that ends the
/// area.
///
/// \returns true if at least one area was found.
bool X86PreAMXConfig::findConfigShapes(
    DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes) {
  bool FoundArea = false;

  for (BasicBlock &Block : F) {
    BasicBlock::iterator It = Block.begin();
    const BasicBlock::iterator BlockEnd = Block.end();

    while (It != BlockEnd) {
      auto *Intr = dyn_cast<IntrinsicInst>(&*It);
      if (Intr && isAMXIntrinsic(Intr)) {
        assert(onlyTileDef(Intr) && "Not volatile model for AMX at O0!");

        // Create the map slot keyed by the area's first instruction, then
        // jump to the tilestore that closes the area.
        SmallVector<Value *, 8> &AreaShapes = PosAndShapes[&*It];
        It = getShapesAndConfigPosEnd(It, AreaShapes);
        FoundArea = true;
      }
      ++It;
    }
  }

  return FoundArea;
}
|
|
|
|
// Insert ldtilecfg and preconfig the shapes for each area of key AMX intrinsic.
|
|
// e.g. (key amx = tdpbssd)
|
|
// --------------------------------------------------------------------------
|
|
// %cfgmem = alloca <16 x i32>, align 4 * allocate mem
|
|
// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem * zero init
|
|
// ...
|
|
// ... pre-config shape of %t1 *
|
|
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1 *
|
|
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2 * pre-config
|
|
// ... *
|
|
// ... pre-config shape of %t2 *
|
|
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 * shapes
|
|
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 *
|
|
// ... *
|
|
// ... pre-config shape of %t3 * of
|
|
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1 *
|
|
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2 *
|
|
// ... * tiles
|
|
// ... pre-config shape of %td *
|
|
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1 *
|
|
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2 *
|
|
//
|
|
// call void @llvm.x86.ldtilecfg(i8* %cfgmem) * tile config
|
|
// --------------------------------------------------------------------------
|
|
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
|
|
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
|
|
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
|
|
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
|
|
// call void @llvm.x86.tilestored64.internal(... td) area
|
|
// --------------------------------------------------------------------------
|
|
bool X86PreAMXConfig::preTileConfig() {
|
|
DenseMap<Instruction *, SmallVector<Value *, 8>> PosAndShapes;
|
|
bool NeedCfg = findConfigShapes(PosAndShapes);
|
|
if (!NeedCfg)
|
|
return false;
|
|
for (auto &IPAndShapes : PosAndShapes)
|
|
addTileConfig(IPAndShapes.first, IPAndShapes.second);
|
|
|
|
return true;
|
|
}
|
|
} // anonymous namespace
|
|
|
|
namespace {
|
|
|
|
class X86PreAMXConfigPass : public FunctionPass {
|
|
public:
|
|
static char ID;
|
|
|
|
X86PreAMXConfigPass() : FunctionPass(ID) {
|
|
initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
|
|
}
|
|
|
|
bool runOnFunction(Function &F) override {
|
|
TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
|
|
bool C = false;
|
|
|
|
// Prepare for fast register allocation at O0.
|
|
if (TM->getOptLevel() == CodeGenOpt::None) {
|
|
|
|
// We pre-config each key AMX intrinsic at O0.
|
|
// In theory, one tile config can cover several AMX intrinsics, but
|
|
// it is very diffcult to classify the tile shapes at O0. So here we
|
|
// let thing be easy, pre-config every key AMX intrinsic.
|
|
X86PreAMXConfig PCFG(F);
|
|
C = PCFG.preTileConfig();
|
|
}
|
|
|
|
return C;
|
|
}
|
|
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override {
|
|
AU.setPreservesCFG();
|
|
AU.addRequired<TargetPassConfig>();
|
|
}
|
|
};
|
|
|
|
} // anonymous namespace
|
|
|
|
static const char PassName[] = "Pre AMX Tile Config";
|
|
char X86PreAMXConfigPass::ID = 0;
|
|
INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
|
|
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
|
|
INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
|
|
|
|
/// Create a new, heap-allocated instance of the pre-AMX tile config pass.
FunctionPass *llvm::createX86PreAMXConfigPass() {
  FunctionPass *Pass = new X86PreAMXConfigPass();
  return Pass;
}
|