[MLIR][Sparse] Move `buildLattices` into Merger

This allows us to use `buildLattices` in the `Merger` unittests.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D104879
This commit is contained in:
Gus Smith 2021-06-25 23:37:53 +00:00
parent e074d580b2
commit 043ce4e6bd
3 changed files with 35 additions and 33 deletions

View File

@ -83,7 +83,7 @@ public:
/// additional synthetic tensor at the end of this set to represent all
/// invariant expressions in the kernel.
Merger(unsigned t, unsigned l)
: outTensor(t - 1), numTensors(t + 1), numLoops(l),
: outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
/// Adds a tensor expression. Returns its index.
@ -148,6 +148,11 @@ public:
/// Returns true if any set bit corresponds to queried dim.
bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;
/// Builds the iteration lattices in a bottom-up traversal given the remaining
/// tensor (sub)expression and the next loop index in the iteration graph.
/// Returns index of the root expression.
unsigned buildLattices(unsigned exp, unsigned idx);
/// Setter
void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
@ -166,6 +171,7 @@ public:
private:
const unsigned outTensor;
const unsigned syntheticTensor;
const unsigned numTensors;
const unsigned numLoops;

View File

@ -302,37 +302,6 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
return false;
}
/// Builds the iteration lattices in a bottom-up traversal given the remaining
/// tensor (sub)expression and the next loop index in the iteration graph.
static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
unsigned exp, unsigned idx) {
Kind kind = merger.exp(exp).kind;
if (kind == Kind::kTensor || kind == Kind::kInvariant) {
// Either the index is really used in the tensor expression, or it is
// set to the undefined index in that dimension. An invariant expression
// is set to a synthetic tensor with undefined indices only.
unsigned s = merger.addSet();
unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0
: op.getNumInputsAndOutputs();
merger.set(s).push_back(merger.addLat(t, idx, exp));
return s;
}
unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
switch (kind) {
case Kind::kTensor:
case Kind::kInvariant:
llvm_unreachable("handled above");
case Kind::kMulF:
case Kind::kMulI:
return merger.takeConj(kind, s0, s1);
case Kind::kAddF:
case Kind::kAddI:
return merger.takeDisj(kind, s0, s1);
}
llvm_unreachable("unexpected expression kind");
}
/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
if (width == 0)
@ -1121,7 +1090,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
// in play for a non-singleton loop sequence.
Location loc = op.getLoc();
unsigned idx = topSort[at];
unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
unsigned lsize = merger.set(lts).size();
assert(lsize != 0);
unsigned l0 = merger.set(lts)[0];

View File

@ -137,6 +137,33 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
return false;
}
unsigned Merger::buildLattices(unsigned e, unsigned idx) {
Kind kind = exp(e).kind;
if (kind == Kind::kTensor || kind == Kind::kInvariant) {
// Either the index is really used in the tensor expression, or it is
// set to the undefined index in that dimension. An invariant expression
// is set to a synthetic tensor with undefined indices only.
unsigned s = addSet();
unsigned t = kind == Kind::kTensor ? exp(e).e0 : syntheticTensor;
set(s).push_back(addLat(t, idx, e));
return s;
}
unsigned s0 = buildLattices(exp(e).e0, idx);
unsigned s1 = buildLattices(exp(e).e1, idx);
switch (kind) {
case Kind::kTensor:
case Kind::kInvariant:
llvm_unreachable("handled above");
case Kind::kMulF:
case Kind::kMulI:
return takeConj(kind, s0, s1);
case Kind::kAddF:
case Kind::kAddI:
return takeDisj(kind, s0, s1);
}
llvm_unreachable("unexpected expression kind");
}
#ifndef NDEBUG
//