forked from OSchip/llvm-project
[Matrix] Hoist load/store generation logic, add helpers for tiled access.
This patch slightly generalizes the code to emit loads and stores of a matrix and adds helpers to load/store a tile of a larger matrix. This will be used in a follow-up patch introducing initial tiling. Reviewers: anemet, Gerolf, hfinkel, andrew.w.kaylor, LuoYuanke Reviewed By: anemet Differential Revision: https://reviews.llvm.org/D75564
This commit is contained in:
parent
c2586cab89
commit
0cc2d23751
|
@ -181,8 +181,8 @@ class LowerMatrixIntrinsics {
|
|||
|
||||
void setColumn(unsigned i, Value *V) { Columns[i] = V; }
|
||||
|
||||
size_t getNumColumns() const { return Columns.size(); }
|
||||
size_t getNumRows() const {
|
||||
unsigned getNumColumns() const { return Columns.size(); }
|
||||
unsigned getNumRows() const {
|
||||
assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
|
||||
return cast<VectorType>(Columns[0]->getType())->getNumElements();
|
||||
}
|
||||
|
@ -634,10 +634,11 @@ public:
|
|||
return true;
|
||||
}
|
||||
|
||||
void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
|
||||
ShapeInfo Shape) {
|
||||
IRBuilder<> Builder(Inst);
|
||||
auto VType = cast<VectorType>(Inst->getType());
|
||||
/// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
|
||||
/// columns.
|
||||
ColumnMatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride,
|
||||
ShapeInfo Shape, IRBuilder<> &Builder) {
|
||||
auto VType = cast<VectorType>(Ty);
|
||||
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
|
||||
ColumnMatrixTy Result;
|
||||
// Distance between start of one column and the start of the next
|
||||
|
@ -648,10 +649,41 @@ public:
|
|||
Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
|
||||
Result.addColumn(Column);
|
||||
}
|
||||
return Result.addNumLoads(getNumOps(Result.getColumnTy()) *
|
||||
Result.getNumColumns());
|
||||
}
|
||||
|
||||
/// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
|
||||
/// starting at \p MatrixPtr[I][J].
|
||||
ColumnMatrixTy loadMatrix(Value *MatrixPtr, ShapeInfo MatrixShape, unsigned I,
|
||||
unsigned J, ShapeInfo ResultShape, Type *EltTy,
|
||||
IRBuilder<> &Builder) {
|
||||
|
||||
Value *Offset = Builder.CreateAdd(
|
||||
Builder.CreateMul(Builder.getInt32(J),
|
||||
Builder.getInt32(MatrixShape.NumRows)),
|
||||
Builder.getInt32(I));
|
||||
|
||||
unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
|
||||
Value *EltPtr =
|
||||
Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
|
||||
Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
|
||||
Type *TileTy =
|
||||
VectorType::get(EltTy, ResultShape.NumRows * ResultShape.NumColumns);
|
||||
Type *TilePtrTy = PointerType::get(TileTy, AS);
|
||||
Value *TilePtr =
|
||||
Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
|
||||
|
||||
return loadMatrix(TileTy, TilePtr, Builder.getInt32(ResultShape.NumRows),
|
||||
ResultShape, Builder);
|
||||
}
|
||||
|
||||
/// Lower a load instruction with shape information.
|
||||
void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
|
||||
ShapeInfo Shape) {
|
||||
IRBuilder<> Builder(Inst);
|
||||
finalizeLowering(Inst,
|
||||
Result.addNumLoads(getNumOps(Result.getColumnTy()) *
|
||||
Result.getNumColumns()),
|
||||
loadMatrix(Inst->getType(), Ptr, Stride, Shape, Builder),
|
||||
Builder);
|
||||
}
|
||||
|
||||
|
@ -665,22 +697,54 @@ public:
|
|||
{Inst->getArgOperand(2), Inst->getArgOperand(3)});
|
||||
}
|
||||
|
||||
/// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
|
||||
/// MatrixPtr[I][J].
|
||||
void storeMatrix(const ColumnMatrixTy &StoreVal, Value *MatrixPtr,
|
||||
ShapeInfo MatrixShape, unsigned I, unsigned J, Type *EltTy,
|
||||
IRBuilder<> &Builder) {
|
||||
Value *Offset = Builder.CreateAdd(
|
||||
Builder.CreateMul(Builder.getInt32(J),
|
||||
Builder.getInt32(MatrixShape.NumRows)),
|
||||
Builder.getInt32(I));
|
||||
|
||||
unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
|
||||
Value *EltPtr =
|
||||
Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
|
||||
Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
|
||||
Type *TileTy = VectorType::get(EltTy, StoreVal.getNumRows() *
|
||||
StoreVal.getNumColumns());
|
||||
Type *TilePtrTy = PointerType::get(TileTy, AS);
|
||||
Value *TilePtr =
|
||||
Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
|
||||
|
||||
storeMatrix(TileTy, StoreVal, TilePtr,
|
||||
Builder.getInt32(StoreVal.getNumRows()), Builder);
|
||||
}
|
||||
|
||||
/// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
|
||||
/// columns.
|
||||
ColumnMatrixTy storeMatrix(Type *Ty, ColumnMatrixTy StoreVal, Value *Ptr,
|
||||
Value *Stride, IRBuilder<> &Builder) {
|
||||
auto VType = cast<VectorType>(Ty);
|
||||
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
|
||||
for (auto C : enumerate(StoreVal.columns())) {
|
||||
Value *GEP = computeColumnAddr(EltPtr, Builder.getInt32(C.index()),
|
||||
Stride, StoreVal.getNumRows(),
|
||||
VType->getElementType(), Builder);
|
||||
createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
|
||||
}
|
||||
return ColumnMatrixTy().addNumStores(getNumOps(StoreVal.getColumnTy()) *
|
||||
StoreVal.getNumColumns());
|
||||
}
|
||||
|
||||
/// Lower a store instruction with shape information.
|
||||
void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
|
||||
ShapeInfo Shape) {
|
||||
IRBuilder<> Builder(Inst);
|
||||
auto VType = cast<VectorType>(Matrix->getType());
|
||||
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
|
||||
auto LM = getMatrix(Matrix, Shape, Builder);
|
||||
for (auto C : enumerate(LM.columns())) {
|
||||
Value *GEP =
|
||||
computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
|
||||
Shape.NumRows, VType->getElementType(), Builder);
|
||||
createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
|
||||
}
|
||||
Inst2ColumnMatrix[Inst] = ColumnMatrixTy().addNumStores(
|
||||
getNumOps(LM.getColumnTy()) * LM.getNumColumns());
|
||||
|
||||
ToRemove.push_back(Inst);
|
||||
auto StoreVal = getMatrix(Matrix, Shape, Builder);
|
||||
finalizeLowering(
|
||||
Inst, storeMatrix(Matrix->getType(), StoreVal, Ptr, Stride, Builder),
|
||||
Builder);
|
||||
}
|
||||
|
||||
/// Lowers llvm.matrix.columnwise.store.
|
||||
|
|
Loading…
Reference in New Issue