From 043ce4e6bdd376ff460d78446d1a6b94c6e0f18c Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Fri, 25 Jun 2021 23:37:53 +0000 Subject: [PATCH] [MLIR][Sparse] Move `buildLattices` into Merger This allows us to use `buildLattices` in the `Merger` unittests. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D104879 --- .../mlir/Dialect/SparseTensor/Utils/Merger.h | 8 ++++- .../Transforms/Sparsification.cpp | 33 +------------------ .../lib/Dialect/SparseTensor/Utils/Merger.cpp | 27 +++++++++++++++ 3 files changed, 35 insertions(+), 33 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h index 9457276ba874..cbb0aede83f8 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -83,7 +83,7 @@ public: /// additional synthetic tensor at the end of this set to represent all /// invariant expressions in the kernel. Merger(unsigned t, unsigned l) - : outTensor(t - 1), numTensors(t + 1), numLoops(l), + : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l), dims(t + 1, std::vector(l, Dim::kUndef)) {} /// Adds a tensor expression. Returns its index. @@ -148,6 +148,11 @@ public: /// Returns true if any set bit corresponds to queried dim. bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const; + /// Builds the iteration lattices in a bottom-up traversal given the remaining + /// tensor (sub)expression and the next loop index in the iteration graph. + /// Returns index of the root expression. + unsigned buildLattices(unsigned exp, unsigned idx); + /// Setter void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; } @@ -166,6 +171,7 @@ public: private: const unsigned outTensor; + const unsigned syntheticTensor; const unsigned numTensors; const unsigned numLoops; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index f12aaccb3169..2b6d5d1caf15 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -302,37 +302,6 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op, return false; } -/// Builds the iteration lattices in a bottom-up traversal given the remaining -/// tensor (sub)expression and the next loop index in the iteration graph. -static unsigned buildLattices(Merger &merger, linalg::GenericOp op, - unsigned exp, unsigned idx) { - Kind kind = merger.exp(exp).kind; - if (kind == Kind::kTensor || kind == Kind::kInvariant) { - // Either the index is really used in the tensor expression, or it is - // set to the undefined index in that dimension. An invariant expression - // is set to a synthetic tensor with undefined indices only. - unsigned s = merger.addSet(); - unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0 - : op.getNumInputsAndOutputs(); - merger.set(s).push_back(merger.addLat(t, idx, exp)); - return s; - } - unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx); - unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx); - switch (kind) { - case Kind::kTensor: - case Kind::kInvariant: - llvm_unreachable("handled above"); - case Kind::kMulF: - case Kind::kMulI: - return merger.takeConj(kind, s0, s1); - case Kind::kAddF: - case Kind::kAddI: - return merger.takeDisj(kind, s0, s1); - } - llvm_unreachable("unexpected expression kind"); -} - /// Maps sparse integer option to actual integral storage type. static Type genIntType(PatternRewriter &rewriter, unsigned width) { if (width == 0) @@ -1121,7 +1090,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, // in play for a non-singleton loop sequence. Location loc = op.getLoc(); unsigned idx = topSort[at]; - unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx)); + unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx)); unsigned lsize = merger.set(lts).size(); assert(lsize != 0); unsigned l0 = merger.set(lts)[0]; diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index af864b764a6e..3d63246e950f 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -137,6 +137,33 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const { return false; } +unsigned Merger::buildLattices(unsigned e, unsigned idx) { + Kind kind = exp(e).kind; + if (kind == Kind::kTensor || kind == Kind::kInvariant) { + // Either the index is really used in the tensor expression, or it is + // set to the undefined index in that dimension. An invariant expression + // is set to a synthetic tensor with undefined indices only. + unsigned s = addSet(); + unsigned t = kind == Kind::kTensor ? exp(e).e0 : syntheticTensor; + set(s).push_back(addLat(t, idx, e)); + return s; + } + unsigned s0 = buildLattices(exp(e).e0, idx); + unsigned s1 = buildLattices(exp(e).e1, idx); + switch (kind) { + case Kind::kTensor: + case Kind::kInvariant: + llvm_unreachable("handled above"); + case Kind::kMulF: + case Kind::kMulI: + return takeConj(kind, s0, s1); + case Kind::kAddF: + case Kind::kAddI: + return takeDisj(kind, s0, s1); + } + llvm_unreachable("unexpected expression kind"); +} + #ifndef NDEBUG //