[Matrix] Hoist load/store generation logic, add helpers for tiled access.

This patch slightly generalizes the code to emit loads and stores of a
matrix and adds helpers to load/store a tile of a larger matrix.

This will be used in a follow-up patch introducing initial tiling.

Reviewers: anemet, Gerolf, hfinkel, andrew.w.kaylor, LuoYuanke

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D75564
This commit is contained in:
Florian Hahn 2020-03-19 19:15:23 +00:00
parent c2586cab89
commit 0cc2d23751
1 changed files with 85 additions and 21 deletions

View File

@ -181,8 +181,8 @@ class LowerMatrixIntrinsics {
void setColumn(unsigned i, Value *V) { Columns[i] = V; }
size_t getNumColumns() const { return Columns.size(); }
size_t getNumRows() const {
unsigned getNumColumns() const { return Columns.size(); }
unsigned getNumRows() const {
assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
return cast<VectorType>(Columns[0]->getType())->getNumElements();
}
@ -634,10 +634,11 @@ public:
return true;
}
void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
ShapeInfo Shape) {
IRBuilder<> Builder(Inst);
auto VType = cast<VectorType>(Inst->getType());
/// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
/// columns.
ColumnMatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride,
ShapeInfo Shape, IRBuilder<> &Builder) {
auto VType = cast<VectorType>(Ty);
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
ColumnMatrixTy Result;
// Distance between start of one column and the start of the next
@ -648,10 +649,41 @@ public:
Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
Result.addColumn(Column);
}
return Result.addNumLoads(getNumOps(Result.getColumnTy()) *
Result.getNumColumns());
}
/// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
/// starting at \p MatrixPtr[I][J].
ColumnMatrixTy loadMatrix(Value *MatrixPtr, ShapeInfo MatrixShape, unsigned I,
unsigned J, ShapeInfo ResultShape, Type *EltTy,
IRBuilder<> &Builder) {
Value *Offset = Builder.CreateAdd(
Builder.CreateMul(Builder.getInt32(J),
Builder.getInt32(MatrixShape.NumRows)),
Builder.getInt32(I));
unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
Value *EltPtr =
Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
Type *TileTy =
VectorType::get(EltTy, ResultShape.NumRows * ResultShape.NumColumns);
Type *TilePtrTy = PointerType::get(TileTy, AS);
Value *TilePtr =
Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
return loadMatrix(TileTy, TilePtr, Builder.getInt32(ResultShape.NumRows),
ResultShape, Builder);
}
/// Lower a load instruction with shape information.
void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
ShapeInfo Shape) {
IRBuilder<> Builder(Inst);
finalizeLowering(Inst,
Result.addNumLoads(getNumOps(Result.getColumnTy()) *
Result.getNumColumns()),
loadMatrix(Inst->getType(), Ptr, Stride, Shape, Builder),
Builder);
}
@ -665,22 +697,54 @@ public:
{Inst->getArgOperand(2), Inst->getArgOperand(3)});
}
/// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
/// MatrixPtr[I][J].
void storeMatrix(const ColumnMatrixTy &StoreVal, Value *MatrixPtr,
ShapeInfo MatrixShape, unsigned I, unsigned J, Type *EltTy,
IRBuilder<> &Builder) {
Value *Offset = Builder.CreateAdd(
Builder.CreateMul(Builder.getInt32(J),
Builder.getInt32(MatrixShape.NumRows)),
Builder.getInt32(I));
unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
Value *EltPtr =
Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
Type *TileTy = VectorType::get(EltTy, StoreVal.getNumRows() *
StoreVal.getNumColumns());
Type *TilePtrTy = PointerType::get(TileTy, AS);
Value *TilePtr =
Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
storeMatrix(TileTy, StoreVal, TilePtr,
Builder.getInt32(StoreVal.getNumRows()), Builder);
}
/// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
/// columns.
ColumnMatrixTy storeMatrix(Type *Ty, ColumnMatrixTy StoreVal, Value *Ptr,
Value *Stride, IRBuilder<> &Builder) {
auto VType = cast<VectorType>(Ty);
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
for (auto C : enumerate(StoreVal.columns())) {
Value *GEP = computeColumnAddr(EltPtr, Builder.getInt32(C.index()),
Stride, StoreVal.getNumRows(),
VType->getElementType(), Builder);
createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
}
return ColumnMatrixTy().addNumStores(getNumOps(StoreVal.getColumnTy()) *
StoreVal.getNumColumns());
}
/// Lower a store instruction with shape information.
void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
ShapeInfo Shape) {
IRBuilder<> Builder(Inst);
auto VType = cast<VectorType>(Matrix->getType());
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
auto LM = getMatrix(Matrix, Shape, Builder);
for (auto C : enumerate(LM.columns())) {
Value *GEP =
computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
Shape.NumRows, VType->getElementType(), Builder);
createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
}
Inst2ColumnMatrix[Inst] = ColumnMatrixTy().addNumStores(
getNumOps(LM.getColumnTy()) * LM.getNumColumns());
ToRemove.push_back(Inst);
auto StoreVal = getMatrix(Matrix, Shape, Builder);
finalizeLowering(
Inst, storeMatrix(Matrix->getType(), StoreVal, Ptr, Stride, Builder),
Builder);
}
/// Lowers llvm.matrix.columnwise.store.