[mlir][sparse] hoist loop invariant tensor loads in sparse compiler

After bufferization, the backend has much more trouble hoisting loop invariant
loads from the loops generated by the sparse compiler. Therefore, this is done
during sparse code generation. Note that we don't bother hoisting derived
invariant expressions on SSA values, since the backend does that very well.

Still TBD: scalarize reductions to avoid load-add-store cycles

Reviewed By: penpornk

Differential Revision: https://reviews.llvm.org/D92534
This commit is contained in:
Aart Bik 2020-12-07 11:54:58 -08:00
parent 1c98f98410
commit 74cd9e587d
3 changed files with 161 additions and 42 deletions

View File

@ -59,14 +59,21 @@ enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
/// children tensor expressions.
struct TensorExp {
TensorExp(Kind k, unsigned x, unsigned y, Value v)
: kind(k), e0(x), e1(y), val(v) {}
: kind(k), e0(x), e1(y), val(v) {
assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
(kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
(kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
}
Kind kind;
/// Indices of children expression(s).
unsigned e0;
unsigned e1;
/// Direct link to IR for an invariant. During code generation,
/// field is used to cache "hoisted" loop invariant tensor loads.
Value val;
};
/// Lattice point. Each lattice point consist of a conjunction of tensor
/// Lattice point. Each lattice point consists of a conjunction of tensor
/// loop indices (encoded in a bitvector) and the index of the corresponding
/// tensor expression.
struct LatPoint {
@ -74,7 +81,9 @@ struct LatPoint {
bits.set(b);
}
LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
/// Conjunction of tensor loop indices as bitvector.
llvm::BitVector bits;
/// Index of the tensor expresssion.
unsigned exp;
};
@ -502,8 +511,16 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
unsigned tensor) {
unsigned exp) {
// Test if the load was hoisted to a higher loop nest.
Value val = merger.exp(exp).val;
if (val) {
merger.exp(exp).val = Value(); // reset
return val;
}
// Actual load.
SmallVector<Value, 4> args;
unsigned tensor = merger.exp(exp).e0;
auto map = op.getIndexingMap(tensor);
bool sparse = false;
for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
@ -515,7 +532,9 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
args.push_back(codegen.pidxs[tensor][idx]); // position index
}
}
return rewriter.create<LoadOp>(op.getLoc(), codegen.buffers[tensor], args);
Location loc = op.getLoc();
Value ptr = codegen.buffers[tensor];
return rewriter.create<LoadOp>(loc, ptr, args);
}
/// Generates a store on a dense tensor.
@ -528,25 +547,33 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
unsigned idx = map.getDimPosition(i);
args.push_back(codegen.loops[idx]); // universal dense index
}
rewriter.create<StoreOp>(op.getLoc(), rhs, codegen.buffers[tensor], args);
Location loc = op.getLoc();
Value ptr = codegen.buffers[tensor];
rewriter.create<StoreOp>(loc, rhs, ptr, args);
}
/// Generates a pointer/index load from the sparse storage scheme.
static Value genIntLoad(PatternRewriter &rewriter, Location loc, Value ptr,
Value s) {
static Value genLoad(PatternRewriter &rewriter, Location loc, Value ptr,
Value s) {
Value load = rewriter.create<LoadOp>(loc, ptr, s);
return load.getType().isa<IndexType>()
? load
: rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
}
/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, unsigned exp) {
return merger.exp(exp).val;
}
/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
linalg::GenericOp op, unsigned exp) {
if (merger.exp(exp).kind == Kind::kTensor)
return genTensorLoad(merger, codegen, rewriter, op, merger.exp(exp).e0);
return genTensorLoad(merger, codegen, rewriter, op, exp);
else if (merger.exp(exp).kind == Kind::kInvariant)
return merger.exp(exp).val;
return genInvariantValue(merger, codegen, rewriter, exp);
Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
switch (merger.exp(exp).kind) {
@ -564,6 +591,33 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
}
}
/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
unsigned exp) {
if (merger.exp(exp).kind == Kind::kTensor) {
unsigned lhs = op.getNumInputsAndOutputs() - 1;
unsigned tensor = merger.exp(exp).e0;
if (tensor == lhs)
return; // TODO: scalarize reduction as well (using scf.yield)
auto map = op.getIndexingMap(tensor);
for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
unsigned idx = map.getDimPosition(i);
if (!codegen.loops[idx])
return; // still in play
}
// All exhausted at this level.
merger.exp(exp).val = genTensorLoad(merger, codegen, rewriter, op, exp);
} else if (merger.exp(exp).kind != Kind::kInvariant) {
// Traverse into the binary operations. Note that we only hoist
// tensor loads, since subsequent MLIR/LLVM passes know how to
// deal with all other kinds of derived loop invariants.
genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e0);
genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e1);
}
}
/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
@ -590,9 +644,9 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
: codegen.pidxs[tensor][topSort[pat - 1]];
codegen.pidxs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p0);
codegen.pidxs[tensor][idx] = genLoad(rewriter, loc, ptr, p0);
Value p1 = rewriter.create<AddIOp>(loc, p0, one);
codegen.highs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p1);
codegen.highs[tensor][idx] = genLoad(rewriter, loc, ptr, p1);
} else {
// Dense index still in play.
needsUniv = true;
@ -608,7 +662,8 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
bool isOuter, unsigned idx, llvm::BitVector &indices) {
bool isOuter, bool isInner, unsigned idx,
llvm::BitVector &indices) {
unsigned fb = indices.find_first();
unsigned tensor = merger.tensor(fb);
assert(idx == merger.index(fb));
@ -725,10 +780,15 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
bool isOuter, unsigned idx, bool needsUniv,
llvm::BitVector &indices) {
if (indices.count() == 1)
return genFor(merger, codegen, rewriter, op, isOuter, idx, indices);
std::vector<unsigned> &topSort, unsigned at,
bool needsUniv, llvm::BitVector &indices) {
unsigned idx = topSort[at];
if (indices.count() == 1) {
bool isOuter = at == 0;
bool isInner = at == topSort.size() - 1;
return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
indices);
}
return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}
@ -749,7 +809,7 @@ static void genLocals(Merger &merger, CodeGen &codegen,
assert(idx == merger.index(b));
Value ptr = codegen.indices[tensor][idx];
Value s = codegen.pidxs[tensor][idx];
Value load = genIntLoad(rewriter, loc, ptr, s);
Value load = genLoad(rewriter, loc, ptr, s);
codegen.idxs[tensor][idx] = load;
if (!needsUniv) {
if (min) {
@ -886,6 +946,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
assert(lsize != 0);
unsigned l0 = merger.set(lts)[0];
LatPoint lat0 = merger.lat(l0);
genInvariants(merger, codegen, rewriter, op, exp);
bool needsUniv =
genInit(merger, codegen, rewriter, op, topSort, at, lat0.bits) &&
lsize > 1;
@ -897,9 +958,8 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
// Emit loop.
llvm::BitVector indices = lati.bits;
optimizeIndices(merger, lsize, indices);
bool isOuter = at == 0;
Operation *loop = genLoop(merger, codegen, rewriter, op, isOuter, idx,
needsUniv, indices);
Operation *loop =
genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, lati.bits);
// Visit all lattices points with Li >= Lj to generate the
@ -931,6 +991,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
}
rewriter.setInsertionPointAfter(loop);
}
codegen.loops[idx] = Value();
}
namespace {

View File

@ -1071,8 +1071,8 @@ func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf
}
// CHECK-LABEL: func @sum_reduction(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK-SAME: %[[VAL_0:.*0]]: tensor<10x20xf32>,
// CHECK-SAME: %[[VAL_1:.*1]]: tensor<f32>) -> tensor<f32> {
// CHECK: %[[VAL_2:.*]] = constant 999 : index
// CHECK: %[[VAL_3:.*]] = constant 10 : index
// CHECK: %[[VAL_4:.*]] = constant 0 : index
@ -1200,19 +1200,19 @@ func @scale(%arga: tensor<?x?xf64>) -> tensor<?x?xf64> {
// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_21]] to %[[VAL_22]] step %[[VAL_6]] {
// CHECK: %[[VAL_24:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_5]] to %[[VAL_15]] step %[[VAL_6]] {
// CHECK: %[[VAL_26:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK: %[[VAL_27:.*]] = addi %[[VAL_23]], %[[VAL_6]] : index
// CHECK: %[[VAL_28:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_26]] to %[[VAL_28]] step %[[VAL_6]] {
// CHECK: %[[VAL_30:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_29]]] : memref<?xindex>
// CHECK: %[[VAL_31:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_30]]] : memref<?x?xf32>
// CHECK: %[[VAL_32:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref<?xf32>
// CHECK: %[[VAL_33:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_25]]] : memref<?x?xf32>
// CHECK: %[[VAL_34:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_30]]] : memref<?x?xf32>
// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_33]], %[[VAL_34]] : f32
// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_32]], %[[VAL_35]] : f32
// CHECK: %[[VAL_37:.*]] = addf %[[VAL_31]], %[[VAL_36]] : f32
// CHECK: store %[[VAL_37]], %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_30]]] : memref<?x?xf32>
// CHECK: %[[VAL_26:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_25]]] : memref<?x?xf32>
// CHECK: %[[VAL_27:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK: %[[VAL_28:.*]] = addi %[[VAL_23]], %[[VAL_6]] : index
// CHECK: %[[VAL_29:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_28]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_30:.*]] = %[[VAL_27]] to %[[VAL_29]] step %[[VAL_6]] {
// CHECK: %[[VAL_31:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<?xindex>
// CHECK: %[[VAL_32:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_31]]] : memref<?x?xf32>
// CHECK: %[[VAL_33:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xf32>
// CHECK: %[[VAL_34:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_31]]] : memref<?x?xf32>
// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_26]], %[[VAL_34]] : f32
// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_33]], %[[VAL_35]] : f32
// CHECK: %[[VAL_37:.*]] = addf %[[VAL_32]], %[[VAL_36]] : f32
// CHECK: store %[[VAL_37]], %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_31]]] : memref<?x?xf32>
// CHECK: }
// CHECK: }
// CHECK: }

View File

@ -1192,15 +1192,15 @@ func @mul_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
// CHECK: %[[VAL_25:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_26:.*]] = %[[VAL_23]] to %[[VAL_25]] step %[[VAL_6]] {
// CHECK: %[[VAL_27:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_28:.*]] = %[[VAL_5]] to %[[VAL_17]] step %[[VAL_6]] {
// CHECK: %[[VAL_29:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref<?xf32>
// CHECK: %[[VAL_30:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_20]], %[[VAL_28]]] : memref<?x?xf32>
// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_29]], %[[VAL_30]] : f32
// CHECK: %[[VAL_32:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_27]], %[[VAL_28]]] : memref<?x?xf32>
// CHECK: %[[VAL_28:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref<?xf32>
// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_5]] to %[[VAL_17]] step %[[VAL_6]] {
// CHECK: %[[VAL_30:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_20]], %[[VAL_29]]] : memref<?x?xf32>
// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_28]], %[[VAL_30]] : f32
// CHECK: %[[VAL_32:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_27]], %[[VAL_29]]] : memref<?x?xf32>
// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_31]], %[[VAL_32]] : f32
// CHECK: %[[VAL_34:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_28]]] : memref<?x?xf32>
// CHECK: %[[VAL_34:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_29]]] : memref<?x?xf32>
// CHECK: %[[VAL_35:.*]] = addf %[[VAL_33]], %[[VAL_34]] : f32
// CHECK: store %[[VAL_35]], %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_28]]] : memref<?x?xf32>
// CHECK: store %[[VAL_35]], %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_29]]] : memref<?x?xf32>
// CHECK: }
// CHECK: }
// CHECK: }
@ -1281,3 +1281,61 @@ func @sum_reduction(%arga: tensor<10x20x30xf32>, %argx: tensor<f32>) -> tensor<f
} -> tensor<f32>
return %0 : tensor<f32>
}
#trait_invariants = {
indexing_maps = [
affine_map<(i,j,k) -> (i)>, // a
affine_map<(i,j,k) -> (j)>, // b
affine_map<(i,j,k) -> (k)>, // c
affine_map<(i,j,k) -> (i,j,k)> // x
],
sparse = [
[ "D" ], // a
[ "D" ], // b
[ "D" ], // c
[ "D", "D", "D" ] // x
],
iterator_types = ["parallel", "parallel", "parallel"],
doc = "x(i,j,k) = a(i) * b(j) * c(k)"
}
// CHECK-LABEL: func @invariants(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<20xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<30xf32>) -> tensor<10x20x30xf32> {
// CHECK: %[[VAL_3:.*]] = constant 10 : index
// CHECK: %[[VAL_4:.*]] = constant 20 : index
// CHECK: %[[VAL_5:.*]] = constant 30 : index
// CHECK: %[[VAL_6:.*]] = constant 0 : index
// CHECK: %[[VAL_7:.*]] = constant 1 : index
// CHECK: %[[VAL_8:.*]] = alloca() : memref<10xf32>
// CHECK: %[[VAL_9:.*]] = alloca() : memref<20xf32>
// CHECK: %[[VAL_10:.*]] = alloca() : memref<30xf32>
// CHECK: %[[VAL_11:.*]] = alloca() : memref<10x20x30xf32>
// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
// CHECK: %[[VAL_13:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<10xf32>
// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
// CHECK: %[[VAL_15:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<20xf32>
// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
// CHECK: %[[VAL_17:.*]] = mulf %[[VAL_13]], %[[VAL_15]] : f32
// CHECK: %[[VAL_18:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<30xf32>
// CHECK: %[[VAL_19:.*]] = mulf %[[VAL_17]], %[[VAL_18]] : f32
// CHECK: store %[[VAL_19]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_14]], %[[VAL_16]]] : memref<10x20x30xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: %[[VAL_20:.*]] = tensor_load %[[VAL_11]] : memref<10x20x30xf32>
// CHECK: return %[[VAL_20]] : tensor<10x20x30xf32>
// CHECK: }
func @invariants(%arga: tensor<10xf32>,
%argb: tensor<20xf32>,
%argc: tensor<30xf32>) -> tensor<10x20x30xf32> {
%0 = linalg.generic #trait_invariants
ins(%arga, %argb, %argc : tensor<10xf32>, tensor<20xf32>, tensor<30xf32>) {
^bb(%a : f32, %b : f32, %c : f32):
%0 = mulf %a, %b : f32
%1 = mulf %0, %c : f32
linalg.yield %1: f32
} -> tensor<10x20x30xf32>
return %0 : tensor<10x20x30xf32>
}