forked from OSchip/llvm-project
680 lines
28 KiB
C++
680 lines
28 KiB
C++
//===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
/// \file Pass to transform amx intrinsics to scalar operations.
|
|
/// This pass is always enabled, but it is a no-op unless -enable-x86-scalar-amx
|
|
/// is given, and it skips functions that are neither compiled at -O0 nor marked
|
|
/// optnone. At -O0 or with optnone, the shape definitions stay next to the amx
|
|
/// intrinsics that use them, so there is no point that post-dominates all the
|
|
/// shape definitions and dominates all the amx intrinsics. To break this
|
|
/// dependency on the shapes, we transform amx intrinsics to scalar operations
|
|
/// so that compilation doesn't fail. In the long term, we should improve fast
|
|
/// register allocation so that it can allocate amx registers.
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
#include "X86.h"
|
|
#include "llvm/ADT/DenseSet.h"
|
|
#include "llvm/ADT/PostOrderIterator.h"
|
|
#include "llvm/Analysis/DomTreeUpdater.h"
|
|
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
|
|
#include "llvm/Analysis/TargetTransformInfo.h"
|
|
#include "llvm/CodeGen/Passes.h"
|
|
#include "llvm/CodeGen/TargetPassConfig.h"
|
|
#include "llvm/CodeGen/ValueTypes.h"
|
|
#include "llvm/IR/DataLayout.h"
|
|
#include "llvm/IR/Function.h"
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/Instructions.h"
|
|
#include "llvm/IR/IntrinsicInst.h"
|
|
#include "llvm/IR/IntrinsicsX86.h"
|
|
#include "llvm/IR/PatternMatch.h"
|
|
#include "llvm/InitializePasses.h"
|
|
#include "llvm/Pass.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Target/TargetMachine.h"
|
|
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
|
#include "llvm/Transforms/Utils/LoopUtils.h"
|
|
|
|
using namespace llvm;
|
|
using namespace PatternMatch;
|
|
|
|
#define DEBUG_TYPE "lower-amx-intrinsics"
|
|
|
|
#ifndef NDEBUG
|
|
static bool isV256I32Ty(Type *Ty) {
|
|
if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
|
|
return FVT->getNumElements() == 256 &&
|
|
FVT->getElementType()->isIntegerTy(32);
|
|
return false;
|
|
}
|
|
#endif
|
|
|
|
// Opt-in switch for AMX scalarization. It defaults to off, and the legacy
// pass's runOnFunction() returns without touching the IR unless it is set.
static cl::opt<bool>
    X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
                    cl::desc("X86: enable AMX scalarization."));
|
|
|
|
namespace {
|
|
class X86LowerAMXIntrinsics {
|
|
Function &Func;
|
|
|
|
public:
|
|
X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
|
|
: Func(F), DTU(DomTU), LI(LoopI) {}
|
|
bool visit();
|
|
|
|
private:
|
|
DomTreeUpdater &DTU;
|
|
LoopInfo *LI;
|
|
BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
|
|
Value *Step, StringRef Name, IRBuilderBase &B,
|
|
Loop *L);
|
|
template <bool IsTileLoad>
|
|
Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
|
|
IRBuilderBase &B, Value *Row, Value *Col,
|
|
Value *Ptr, Value *Stride, Value *Tile);
|
|
template <Intrinsic::ID IntrID>
|
|
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
|
|
IntrID == Intrinsic::x86_tdpbsud_internal ||
|
|
IntrID == Intrinsic::x86_tdpbusd_internal ||
|
|
IntrID == Intrinsic::x86_tdpbuud_internal ||
|
|
IntrID == Intrinsic::x86_tdpbf16ps_internal,
|
|
Value *>::type
|
|
createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
|
|
Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
|
|
Value *RHS);
|
|
template <bool IsTileLoad>
|
|
bool lowerTileLoadStore(Instruction *TileLoadStore);
|
|
template <Intrinsic::ID IntrID>
|
|
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
|
|
IntrID == Intrinsic::x86_tdpbsud_internal ||
|
|
IntrID == Intrinsic::x86_tdpbusd_internal ||
|
|
IntrID == Intrinsic::x86_tdpbuud_internal ||
|
|
IntrID == Intrinsic::x86_tdpbf16ps_internal,
|
|
bool>::type
|
|
lowerTileDP(Instruction *TileDP);
|
|
bool lowerTileZero(Instruction *TileZero);
|
|
};
|
|
} // anonymous namespace
|
|
|
|
/// Create an empty counted loop between \p Preheader and \p Exit.
///
/// The resulting CFG is Preheader -> Header -> Body -> Latch, where Latch
/// branches back to Header while the incremented i16 induction variable (a phi
/// at the top of Header that starts at 0) differs from \p Bound, and to \p Exit
/// otherwise, so the body runs at least once. The first successor of the
/// preheader's branch is redirected to Header, the matching dominator-tree
/// updates are queued on DTU, and, when LoopInfo is available, the three new
/// blocks are added to \p L.
///
/// \returns the body block, whose only instruction is its branch to the latch.
BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
                                              BasicBlock *Exit, Value *Bound,
                                              Value *Step, StringRef Name,
                                              IRBuilderBase &B, Loop *L) {
  LLVMContext &Ctx = Preheader->getContext();
  Function *ParentFn = Preheader->getParent();

  // Lay the new blocks out right before Exit, in header/body/latch order.
  BasicBlock *Header =
      BasicBlock::Create(Ctx, Name + ".header", ParentFn, Exit);
  BasicBlock *Body = BasicBlock::Create(Ctx, Name + ".body", ParentFn, Exit);
  BasicBlock *Latch = BasicBlock::Create(Ctx, Name + ".latch", ParentFn, Exit);

  // Straight-line edges Header -> Body -> Latch.
  BranchInst::Create(Body, Header);
  BranchInst::Create(Latch, Body);

  // The induction variable is 0 when entering from the preheader; the back-edge
  // value is attached once it has been created in the latch.
  IntegerType *IVTy = Type::getInt16Ty(Ctx);
  PHINode *IndVar =
      PHINode::Create(IVTy, 2, Name + ".iv", Header->getTerminator());
  IndVar->addIncoming(ConstantInt::get(IVTy, 0), Preheader);

  B.SetInsertPoint(Latch);
  Value *Next = B.CreateAdd(IndVar, Step, Name + ".step");
  Value *KeepGoing = B.CreateICmpNE(Next, Bound, Name + ".cond");
  BranchInst::Create(Header, Exit, KeepGoing, Latch);
  IndVar->addIncoming(Next, Latch);

  // Splice the loop in: the preheader now enters Header instead of its old
  // first successor.
  auto *EntryBr = cast<BranchInst>(Preheader->getTerminator());
  BasicBlock *OldSucc = EntryBr->getSuccessor(0);
  EntryBr->setSuccessor(0, Header);
  DTU.applyUpdatesPermissive({
      {DominatorTree::Delete, Preheader, OldSucc},
      {DominatorTree::Insert, Header, Body},
      {DominatorTree::Insert, Body, Latch},
      {DominatorTree::Insert, Latch, Header},
      {DominatorTree::Insert, Latch, Exit},
      {DominatorTree::Insert, Preheader, Header},
  });

  if (LI) {
    for (BasicBlock *NewBB : {Header, Body, Latch})
      L->addBasicBlockToLoop(NewBB, *LI);
  }
  return Body;
}
|
|
|
|
/// Emit a row/column loop nest between \p Start and \p End that moves one tile
/// between memory and a <256 x i32> vector, one i32 element per iteration.
///
/// \p Row and \p Col are the i16 trip counts of the outer and inner loop (the
/// caller, lowerTileLoadStore, passes the column count already shifted right
/// by 2). \p Stride is the row pitch counted in i32 elements: element (r, c)
/// is accessed at Ptr + (r * Stride + c) i32 elements and lives at index
/// r * 16 + c of the vector.
///
/// If \p IsTileLoad, the loaded elements are inserted into a vector that starts
/// as zeroinitializer and the final vector value is returned; \p Tile is not
/// used. Otherwise \p Tile must be a bitcast from a <256 x i32> vector whose
/// elements are stored to memory, and nullptr is returned.
template <bool IsTileLoad>
|
|
Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
|
|
BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
|
|
Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
|
|
// Prefix for the names of the generated blocks and values.
std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
|
|
Loop *RowLoop = nullptr;
|
|
Loop *ColLoop = nullptr;
|
|
if (LI) {
|
|
RowLoop = LI->AllocateLoop();
|
|
ColLoop = LI->AllocateLoop();
|
|
RowLoop->addChildLoop(ColLoop);
|
|
if (Loop *ParentL = LI->getLoopFor(Start))
|
|
ParentL->addChildLoop(RowLoop);
|
|
else
|
|
LI->addTopLevelLoop(RowLoop);
|
|
}
|
|
|
|
BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
|
|
IntrinName + ".scalarize.rows", B, RowLoop);
|
|
BasicBlock *RowLatch = RowBody->getSingleSuccessor();
|
|
|
|
// The column loop is nested inside the row body: RowBody is its preheader.
BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
|
|
IntrinName + ".scalarize.cols", B, ColLoop);
|
|
|
|
BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
|
|
BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
|
|
BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
|
|
// createLoop puts the induction-variable phi first in each header.
Value *CurrentRow = &*RowLoopHeader->begin();
|
|
Value *CurrentCol = &*ColLoopHeader->begin();
|
|
Type *EltTy = B.getInt32Ty();
|
|
FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);
|
|
|
|
// Common part for tileload and tilestore
|
|
// *.scalarize.cols.body:
|
|
// Calculate %idxmem and %idxvec
|
|
B.SetInsertPoint(ColBody->getTerminator());
|
|
Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
|
|
Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
|
|
Value *Offset =
|
|
B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
|
|
unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
|
|
Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
|
|
Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
|
|
// Index of element (row, col) in the <256 x i32> vector: row * 16 + col.
Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
|
|
if (IsTileLoad) {
|
|
// tileload.scalarize.rows.header:
|
|
// %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
|
|
// %tileload.scalarize.rows.latch ]
|
|
B.SetInsertPoint(RowLoopHeader->getTerminator());
|
|
Value *VecZero = Constant::getNullValue(V256I32Ty);
|
|
PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
|
|
VecCPhiRowLoop->addIncoming(VecZero, Start);
|
|
|
|
// tileload.scalarize.cols.header:
|
|
// %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
|
|
// ], [ %ResVec, %tileload.scalarize.cols.latch ]
|
|
B.SetInsertPoint(ColLoopHeader->getTerminator());
|
|
PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
|
|
VecPhi->addIncoming(VecCPhiRowLoop, RowBody);
|
|
|
|
// tileload.scalarize.cols.body:
|
|
// Calculate %idxmem and %idxvec
|
|
// %eltptr = getelementptr i32, i32* %base, i64 %idxmem
|
|
// %elt = load i32, i32* %eltptr
|
|
// %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
|
|
B.SetInsertPoint(ColBody->getTerminator());
|
|
Value *Elt = B.CreateLoad(EltTy, EltPtr);
|
|
Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
|
|
// Close both phi cycles with the updated vector.
VecPhi->addIncoming(ResVec, ColLoopLatch);
|
|
VecCPhiRowLoop->addIncoming(ResVec, RowLatch);
|
|
|
|
return ResVec;
|
|
} else {
|
|
auto *BitCast = cast<BitCastInst>(Tile);
|
|
Value *Vec = BitCast->getOperand(0);
|
|
assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
|
|
// tilestore.scalarize.cols.body:
|
|
// %mul = mul i16 %row.iv, i16 16
|
|
// %idx = add i16 %mul, i16 %col.iv
|
|
// %elt = extractelement <256 x i32> %vec, i16 %idx
|
|
// store i32 %elt, i32* %eltptr
|
|
B.SetInsertPoint(ColBody->getTerminator());
|
|
Value *Elt = B.CreateExtractElement(Vec, Idx);
|
|
|
|
B.CreateStore(Elt, EltPtr);
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
/// Emit a three-level loop nest (rows x cols x inner) between \p Start and
/// \p End that computes the tile dot product selected by \p IntrID on
/// <256 x i32> vectors.
///
/// \p Row, \p Col and \p K are the i16 trip counts of the row, column and inner
/// (reduction) loops; the caller, lowerTileDP, passes the column and inner
/// counts already shifted right by 2. \p Acc, \p LHS and \p RHS must each be a
/// bitcast from a <256 x i32> vector (C, A and B respectively). Element (r, c)
/// of a tile lives at vector index r * 16 + c; the inner loop reads A(row,
/// inner) and B(inner, col).
///
/// Integer variants treat each dword of A and B as four i8 lanes, extend them
/// to i32 (signedness depends on the intrinsic), multiply, and add the sum to
/// the C element. tdpbf16ps treats each dword as two bf16 values, widens them
/// to float, multiplies, and adds to the C element as a float.
///
/// \returns %NewVecD, a vector that starts as zeroinitializer and receives the
/// final C element of every (row, col) pair in the column-loop latch.
template <Intrinsic::ID IntrID>
|
|
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
|
|
IntrID == Intrinsic::x86_tdpbsud_internal ||
|
|
IntrID == Intrinsic::x86_tdpbusd_internal ||
|
|
IntrID == Intrinsic::x86_tdpbuud_internal ||
|
|
IntrID == Intrinsic::x86_tdpbf16ps_internal,
|
|
Value *>::type
|
|
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
|
|
IRBuilderBase &B, Value *Row,
|
|
Value *Col, Value *K, Value *Acc,
|
|
Value *LHS, Value *RHS) {
|
|
// Prefix for the names of the generated blocks and values.
std::string IntrinName;
|
|
switch (IntrID) {
|
|
case Intrinsic::x86_tdpbssd_internal:
|
|
IntrinName = "tiledpbssd";
|
|
break;
|
|
case Intrinsic::x86_tdpbsud_internal:
|
|
IntrinName = "tiledpbsud";
|
|
break;
|
|
case Intrinsic::x86_tdpbusd_internal:
|
|
IntrinName = "tiledpbusd";
|
|
break;
|
|
case Intrinsic::x86_tdpbuud_internal:
|
|
IntrinName = "tiledpbuud";
|
|
break;
|
|
case Intrinsic::x86_tdpbf16ps_internal:
|
|
IntrinName = "tiledpbf16ps";
|
|
break;
|
|
}
|
|
Loop *RowLoop = nullptr;
|
|
Loop *ColLoop = nullptr;
|
|
Loop *InnerLoop = nullptr;
|
|
if (LI) {
|
|
RowLoop = LI->AllocateLoop();
|
|
ColLoop = LI->AllocateLoop();
|
|
InnerLoop = LI->AllocateLoop();
|
|
ColLoop->addChildLoop(InnerLoop);
|
|
RowLoop->addChildLoop(ColLoop);
|
|
if (Loop *ParentL = LI->getLoopFor(Start))
|
|
ParentL->addChildLoop(RowLoop);
|
|
else
|
|
LI->addTopLevelLoop(RowLoop);
|
|
}
|
|
|
|
BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
|
|
IntrinName + ".scalarize.rows", B, RowLoop);
|
|
BasicBlock *RowLatch = RowBody->getSingleSuccessor();
|
|
|
|
BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
|
|
IntrinName + ".scalarize.cols", B, ColLoop);
|
|
|
|
BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
|
|
|
|
B.SetInsertPoint(ColBody->getTerminator());
|
|
BasicBlock *InnerBody =
|
|
createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
|
|
IntrinName + ".scalarize.inner", B, InnerLoop);
|
|
|
|
BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
|
|
BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
|
|
BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
|
|
BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
|
|
// createLoop puts the induction-variable phi first in each header.
Value *CurrentRow = &*RowLoopHeader->begin();
|
|
Value *CurrentCol = &*ColLoopHeader->begin();
|
|
Value *CurrentInner = &*InnerLoopHeader->begin();
|
|
|
|
FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
|
|
auto *BitCastAcc = cast<BitCastInst>(Acc);
|
|
Value *VecC = BitCastAcc->getOperand(0);
|
|
assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
|
|
// TODO else create BitCast from x86amx to v256i32.
|
|
// Store x86amx to memory, and reload from memory
|
|
// to vector. However with -O0, it doesn't happen.
|
|
auto *BitCastLHS = cast<BitCastInst>(LHS);
|
|
Value *VecA = BitCastLHS->getOperand(0);
|
|
assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
|
|
auto *BitCastRHS = cast<BitCastInst>(RHS);
|
|
Value *VecB = BitCastRHS->getOperand(0);
|
|
assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");
|
|
|
|
// tiledpbssd.scalarize.rows.header:
|
|
// %vec.c.phi.row = phi <256 x i32> [ %VecC, %start ], [ %NewVecC,
|
|
// %tiledpbssd.scalarize.rows.latch ]
|
|
|
|
// %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %start ], [
|
|
// %NewVecD, %tiledpbssd.scalarize.rows.latch ]
|
|
B.SetInsertPoint(RowLoopHeader->getTerminator());
|
|
PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
|
|
VecCPhiRowLoop->addIncoming(VecC, Start);
|
|
Value *VecZero = Constant::getNullValue(V256I32Ty);
|
|
PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
|
|
VecDPhiRowLoop->addIncoming(VecZero, Start);
|
|
|
|
// tiledpbssd.scalarize.cols.header:
|
|
// %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
|
|
// %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
|
|
// %tiledpbssd.scalarize.cols.latch ]
|
|
|
|
// %vec.d.phi.col = phi <256 x i32> [
|
|
// %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
|
|
// %tiledpbssd.scalarize.cols.latch ]
|
|
|
|
// calculate idxc.
|
|
B.SetInsertPoint(ColLoopHeader->getTerminator());
|
|
PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
|
|
VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
|
|
PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
|
|
VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
|
|
Value *IdxC =
|
|
B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
|
|
|
|
// tiledpbssd.scalarize.inner.header:
|
|
// %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
|
|
// %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
|
|
// %tiledpbssd.scalarize.inner.latch ]
|
|
|
|
B.SetInsertPoint(InnerLoopHeader->getTerminator());
|
|
PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
|
|
VecCPhi->addIncoming(VecCPhiColLoop, ColBody);
|
|
|
|
B.SetInsertPoint(InnerBody->getTerminator());
|
|
// A is read at (row, inner) and B at (inner, col).
Value *IdxA =
|
|
B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
|
|
Value *IdxB =
|
|
B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
|
|
Value *NewVecC = nullptr;
|
|
|
|
if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
|
|
// tiledpbssd.scalarize.inner.body:
|
|
// calculate idxa, idxb
|
|
// %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
|
|
// %elta = extractelement <256 x i32> %veca, i16 %idxa
|
|
// %eltav4i8 = bitcast i32 %elta to <4 x i8>
|
|
// %eltb = extractelement <256 x i32> %vecb, i16 %idxb
|
|
// %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
|
|
// %eltav4i32 = sext/zext <4 x i8> %eltav4i8 to <4 x i32>
|
|
// %eltbv4i32 = sext/zext <4 x i8> %eltbv4i8 to <4 x i32>
|
|
// %mulab = mul <4 x i32> %eltav4i32, %eltbv4i32
|
|
// %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab)
|
|
// %neweltc = add i32 %eltc, %acc
|
|
// %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
|
|
// i16 %idxc
|
|
FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
|
|
FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
|
|
Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
|
|
Value *EltA = B.CreateExtractElement(VecA, IdxA);
|
|
Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
|
|
Value *EltB = B.CreateExtractElement(VecB, IdxB);
|
|
Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
|
|
// Despite the SEXT names, these hold the sign- OR zero-extended lanes:
// ssd: A signed, B signed; sud: A signed, B unsigned;
// usd: A unsigned, B signed; uud: A unsigned, B unsigned.
Value *SEXTSubVecB = nullptr;
|
|
Value *SEXTSubVecA = nullptr;
|
|
switch (IntrID) {
|
|
case Intrinsic::x86_tdpbssd_internal:
|
|
SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
|
|
SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
|
|
break;
|
|
case Intrinsic::x86_tdpbsud_internal:
|
|
SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
|
|
SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
|
|
break;
|
|
case Intrinsic::x86_tdpbusd_internal:
|
|
SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
|
|
SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
|
|
break;
|
|
case Intrinsic::x86_tdpbuud_internal:
|
|
SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
|
|
SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
|
|
break;
|
|
default:
|
|
llvm_unreachable("Invalid intrinsic ID!");
|
|
}
|
|
Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
|
|
Value *ResElt = B.CreateAdd(EltC, SubVecR);
|
|
NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
|
|
} else {
|
|
// tiledpbf16ps.scalarize.inner.body:
|
|
// calculate idxa, idxb, idxc
|
|
// %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
|
|
// %eltcf32 = bitcast i32 %eltc to float
|
|
// %elta = extractelement <256 x i32> %veca, i16 %idxa
|
|
// %eltav2i16 = bitcast i32 %elta to <2 x i16>
|
|
// %eltb = extractelement <256 x i32> %vecb, i16 %idxb
|
|
// %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
|
|
// %shufflea = shufflevector <2 x i16> %eltav2i16, <2 x i16> zeroinitializer,
|
|
// <4 x i32> <i32 2, i32 0, i32 3, i32 1>
|
|
// %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
|
|
// %shuffleb = shufflevector <2 x i16> %eltbv2i16, <2 x i16> zeroinitializer,
|
|
// <4 x i32> <i32 2, i32 0, i32 3, i32 1>
|
|
// %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
|
|
// %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
|
|
// %acc = call float
|
|
// @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
|
|
// %neweltc = bitcast float %acc to i32
|
|
// %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
|
|
// i16 %idxc
|
|
// (%NewVecD is not built here; the cols latch below copies the final
|
|
|
|
// C element into it.)
|
|
FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
|
|
FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
|
|
Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
|
|
Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
|
|
Value *EltA = B.CreateExtractElement(VecA, IdxA);
|
|
Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
|
|
Value *EltB = B.CreateExtractElement(VecB, IdxB);
|
|
Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
|
|
Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
|
|
// Interleave a zero i16 below each bf16 lane ([0, x0, 0, x1]); on the
// little-endian x86 target the bitcast to <2 x float> then yields each bf16
// value widened to f32.
int ShuffleMask[4] = {2, 0, 3, 1};
|
|
auto ShuffleArray = makeArrayRef(ShuffleMask);
|
|
Value *AV2F32 = B.CreateBitCast(
|
|
B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
|
|
Value *BV2F32 = B.CreateBitCast(
|
|
B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
|
|
Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
|
|
Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
|
|
NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
|
|
}
|
|
|
|
// tiledpbssd.scalarize.cols.latch:
|
|
// %NewEltC = extractelement <256 x i32> %NewVecC, i16 %idxc
|
|
// %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
|
|
// i16 %idxc
|
|
B.SetInsertPoint(ColLoopLatch->getTerminator());
|
|
Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
|
|
Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);
|
|
|
|
// Close the phi cycles now that the loop-carried values exist.
VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
|
|
VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
|
|
VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
|
|
VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
|
|
VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);
|
|
|
|
return NewVecD;
|
|
}
|
|
|
|
template <Intrinsic::ID IntrID>
|
|
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
|
|
IntrID == Intrinsic::x86_tdpbsud_internal ||
|
|
IntrID == Intrinsic::x86_tdpbusd_internal ||
|
|
IntrID == Intrinsic::x86_tdpbuud_internal ||
|
|
IntrID == Intrinsic::x86_tdpbf16ps_internal,
|
|
bool>::type
|
|
X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
|
|
Value *M, *N, *K, *C, *A, *B;
|
|
match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
|
|
m_Value(C), m_Value(A), m_Value(B)));
|
|
Instruction *InsertI = TileDP;
|
|
IRBuilder<> PreBuilder(TileDP);
|
|
PreBuilder.SetInsertPoint(TileDP);
|
|
// We visit the loop with (m, n/4, k/4):
|
|
// %n_dword = lshr i16 %n, 2
|
|
// %k_dword = lshr i16 %k, 2
|
|
Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
|
|
Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
|
|
BasicBlock *Start = InsertI->getParent();
|
|
BasicBlock *End =
|
|
SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
|
|
IRBuilder<> Builder(TileDP);
|
|
Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
|
|
KDWord, C, A, B);
|
|
// we cannot assume there always be bitcast after tiledpbssd. So we need to
|
|
// insert one bitcast as required
|
|
Builder.SetInsertPoint(End->getFirstNonPHI());
|
|
Value *ResAMX =
|
|
Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
|
|
// Delete TileDP intrinsic and do some clean-up.
|
|
for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) {
|
|
Instruction *I = cast<Instruction>((UI++)->getUser());
|
|
Value *Vec;
|
|
if (match(I, m_BitCast(m_Value(Vec)))) {
|
|
I->replaceAllUsesWith(ResVec);
|
|
I->eraseFromParent();
|
|
}
|
|
}
|
|
TileDP->replaceAllUsesWith(ResAMX);
|
|
TileDP->eraseFromParent();
|
|
return true;
|
|
}
|
|
|
|
template <bool IsTileLoad>
|
|
bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
|
|
Value *M, *N, *Ptr, *Stride, *Tile;
|
|
if (IsTileLoad)
|
|
match(TileLoadStore,
|
|
m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
|
|
m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
|
|
else
|
|
match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
|
|
m_Value(M), m_Value(N), m_Value(Ptr),
|
|
m_Value(Stride), m_Value(Tile)));
|
|
|
|
Instruction *InsertI = TileLoadStore;
|
|
IRBuilder<> PreBuilder(TileLoadStore);
|
|
PreBuilder.SetInsertPoint(TileLoadStore);
|
|
Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
|
|
Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
|
|
BasicBlock *Start = InsertI->getParent();
|
|
BasicBlock *End =
|
|
SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
|
|
IRBuilder<> Builder(TileLoadStore);
|
|
Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
|
|
Start, End, Builder, M, NDWord, Ptr, StrideDWord,
|
|
IsTileLoad ? nullptr : Tile);
|
|
if (IsTileLoad) {
|
|
// we cannot assume there always be bitcast after tileload. So we need to
|
|
// insert one bitcast as required
|
|
Builder.SetInsertPoint(End->getFirstNonPHI());
|
|
Value *ResAMX =
|
|
Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
|
|
// Delete tileloadd6 intrinsic and do some clean-up
|
|
for (auto UI = TileLoadStore->use_begin(), UE = TileLoadStore->use_end();
|
|
UI != UE;) {
|
|
Instruction *I = cast<Instruction>((UI++)->getUser());
|
|
Value *Vec;
|
|
if (match(I, m_BitCast(m_Value(Vec)))) {
|
|
I->replaceAllUsesWith(ResVec);
|
|
I->eraseFromParent();
|
|
}
|
|
}
|
|
TileLoadStore->replaceAllUsesWith(ResAMX);
|
|
}
|
|
TileLoadStore->eraseFromParent();
|
|
return true;
|
|
}
|
|
|
|
/// Replace a tilezero intrinsic with a zero <256 x i32> vector.
///
/// Bitcast users of the intrinsic are rewired to the zero vector and erased.
/// Any other users receive an explicit bitcast of the zero vector to x86_amx
/// inserted right before the intrinsic, mirroring how lowerTileDP and
/// lowerTileLoadStore treat users they cannot fold. This keeps the intrinsic
/// from being erased while it still has uses. Always returns true.
bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
  IRBuilder<> Builder(TileZero);
  FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  for (auto UI = TileZero->use_begin(), UE = TileZero->use_end(); UI != UE;) {
    // Advance before the user may be erased.
    Instruction *I = cast<Instruction>((UI++)->getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(VecZero);
      I->eraseFromParent();
    }
  }

  if (!TileZero->use_empty()) {
    // Create a real instruction: IRBuilder::CreateBitCast would constant-fold
    // a bitcast whose operand is a constant instead of emitting a cast.
    auto *ZeroAMX = new BitCastInst(
        VecZero, Type::getX86_AMXTy(Builder.getContext()), "", TileZero);
    ZeroAMX->setDebugLoc(TileZero->getDebugLoc());
    TileZero->replaceAllUsesWith(ZeroAMX);
  }

  TileZero->eraseFromParent();
  return true;
}
|
|
|
|
bool X86LowerAMXIntrinsics::visit() {
|
|
bool C = false;
|
|
SmallVector<IntrinsicInst *, 8> WorkList;
|
|
for (BasicBlock *BB : depth_first(&Func)) {
|
|
for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
|
|
if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
|
|
switch (Inst->getIntrinsicID()) {
|
|
case Intrinsic::x86_tdpbssd_internal:
|
|
case Intrinsic::x86_tdpbsud_internal:
|
|
case Intrinsic::x86_tdpbusd_internal:
|
|
case Intrinsic::x86_tdpbuud_internal:
|
|
case Intrinsic::x86_tileloadd64_internal:
|
|
case Intrinsic::x86_tilestored64_internal:
|
|
case Intrinsic::x86_tilezero_internal:
|
|
case Intrinsic::x86_tdpbf16ps_internal:
|
|
WorkList.push_back(Inst);
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
for (auto *Inst : WorkList) {
|
|
switch (Inst->getIntrinsicID()) {
|
|
case Intrinsic::x86_tdpbssd_internal:
|
|
C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
|
|
break;
|
|
case Intrinsic::x86_tdpbsud_internal:
|
|
C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
|
|
break;
|
|
case Intrinsic::x86_tdpbusd_internal:
|
|
C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
|
|
break;
|
|
case Intrinsic::x86_tdpbuud_internal:
|
|
C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
|
|
break;
|
|
case Intrinsic::x86_tdpbf16ps_internal:
|
|
C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
|
|
break;
|
|
case Intrinsic::x86_tileloadd64_internal:
|
|
C = lowerTileLoadStore<true>(Inst) || C;
|
|
break;
|
|
case Intrinsic::x86_tilestored64_internal:
|
|
C = lowerTileLoadStore<false>(Inst) || C;
|
|
break;
|
|
case Intrinsic::x86_tilezero_internal:
|
|
C = lowerTileZero(Inst) || C;
|
|
break;
|
|
default:
|
|
llvm_unreachable("invalid amx intrinsics!");
|
|
}
|
|
}
|
|
|
|
return C;
|
|
}
|
|
|
|
class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
|
|
public:
|
|
static char ID;
|
|
|
|
X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
|
|
initializeX86LowerAMXIntrinsicsLegacyPassPass(
|
|
*PassRegistry::getPassRegistry());
|
|
}
|
|
|
|
bool runOnFunction(Function &F) override {
|
|
if (!X86ScalarizeAMX)
|
|
return false;
|
|
TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
|
|
if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
|
|
TM->getOptLevel() != CodeGenOpt::None)
|
|
return false;
|
|
|
|
auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
|
|
auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
|
|
auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
|
|
auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
|
|
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
|
|
|
|
X86LowerAMXIntrinsics LAT(F, DTU, LI);
|
|
return LAT.visit();
|
|
}
|
|
StringRef getPassName() const override { return "Lower AMX intrinsics"; }
|
|
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override {
|
|
AU.addPreserved<DominatorTreeWrapperPass>();
|
|
AU.addPreserved<LoopInfoWrapperPass>();
|
|
AU.addRequired<TargetPassConfig>();
|
|
}
|
|
};
|
|
|
|
// Legacy pass registration. PassName matches the string returned by
// getPassName(); the TargetPassConfig dependency declared here mirrors the
// addRequired<TargetPassConfig>() call in getAnalysisUsage().
static const char PassName[] = "Lower AMX intrinsics";
|
|
char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
|
|
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
|
|
false, false)
|
|
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
|
|
INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
|
|
false, false)
|
|
|
|
/// Create a new instance of the legacy AMX-scalarization pass. The pass is
/// heap-allocated; presumably the caller (the pass manager) takes ownership.
FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
  auto *Pass = new X86LowerAMXIntrinsicsLegacyPass();
  return Pass;
}
|