diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h new file mode 100644 index 000000000000..0ffd00131fd4 --- /dev/null +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -0,0 +1,163 @@ +//===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines utilities for dealing with iteration lattices. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_ +#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_ + +#include "mlir/IR/Value.h" +#include "llvm/ADT/BitVector.h" + +namespace mlir { +namespace sparse_tensor { + +enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI }; +enum class Dim { kSparse, kDense, kSingle, kUndef }; + +/// Tensor expression. Represents a MLIR expression in tensor index notation. +/// For tensors, e0 denotes the tensor index. For invariants, the IR value is +/// stored directly. For binary operations, e0 and e1 denote the index of the +/// children tensor expressions. +struct TensorExp { + TensorExp(Kind k, unsigned x, unsigned y, Value v) + : kind(k), e0(x), e1(y), val(v) { + assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) || + (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) || + (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val)); + } + Kind kind; + /// Indices of children expression(s). + unsigned e0; + unsigned e1; + /// Direct link to IR for an invariant. During code generation, + /// field is used to cache "hoisted" loop invariant tensor loads. + Value val; +}; + +/// Lattice point. Each lattice point consists of a conjunction of tensor +/// loop indices (encoded in a bitvector) and the index of the corresponding +/// tensor expression. +struct LatPoint { + LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) { + bits.set(b); + } + LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {} + /// Conjunction of tensor loop indices as bitvector. This represents + /// all indices involved in the tensor expression + llvm::BitVector bits; + /// Simplified conjunction of tensor loop indices as bitvector. This + /// represents a simplified condition under which this tensor expression + /// must execute. Pre-computed during codegen to avoid repeated eval. + llvm::BitVector simple; + /// Index of the tensor expresssion. + unsigned exp; +}; + +/// A class to handle all iteration lattice operations. This class abstracts +/// away from some implementation details of storing iteration lattices and +/// tensor expressions. This allows for fine-tuning performance characteristics +/// independently from the basic algorithm if bottlenecks are identified. +class Merger { +public: + /// Constructs a merger for the given number of tensors and loops. The + /// user supplies the number of tensors involved in the kernel, with the + /// last tensor in this set denoting the output tensor. The merger adds an + /// additional synthetic tensor at the end of this set to represent all + /// invariant expressions in the kernel. + Merger(unsigned t, unsigned l) + : outTensor(t - 1), numTensors(t + 1), numLoops(l), + dims(t + 1, std::vector(l, Dim::kUndef)) {} + + /// Adds a tensor expression. Returns its index. + unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()); + unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); } + + /// Adds an iteration lattice point. Returns its index. + unsigned addLat(unsigned t, unsigned i, unsigned e); + + /// Adds a new, initially empty, set. Returns its index. + unsigned addSet(); + + /// Computes a single conjunction of two lattice points by taking the "union" + /// of loop indices (effectively constructing a larger "intersection" of those + /// indices) with a newly constructed tensor (sub)expression of given kind. + /// Returns the index of the new lattice point. + unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1); + + /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of + /// cartesian product. Returns the index of the new set. + unsigned takeConj(Kind kind, unsigned s0, unsigned s1); + + /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1). + /// Returns the index of the new set. + unsigned takeDisj(Kind kind, unsigned s0, unsigned s1); + + /// Optimizes the iteration lattice points in the given set. This + /// method should be called right before code generation to avoid + /// generating redundant loops and conditions. + unsigned optimizeSet(unsigned s0); + + /// Simplifies the conditions in a conjunction of a given lattice point + /// within the given set using just two basic rules: + /// (1) multiple dense conditions are reduced to single dense, and + /// (2) a *singleton* sparse/dense is reduced to sparse/random access. + llvm::BitVector simplifyCond(unsigned s, unsigned p0); + + /// Returns true if Li > Lj. + bool latGT(unsigned i, unsigned j) const; + + /// Returns true if Li and Lj only differ in dense. + bool onlyDenseDiff(unsigned i, unsigned j); + + /// Bit translation. + unsigned tensor(unsigned b) const { return b % numTensors; } + unsigned index(unsigned b) const { return b / numTensors; } + + /// Returns true if bit corresponds to queried dim. + bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); } + + /// Returns true if bit corresponds to index of output tensor. + bool isOutTensor(unsigned b, unsigned i) const { + return tensor(b) == outTensor && index(b) == i; + } + + /// Returns true if tensor access at given index has queried dim. + bool isDim(unsigned t, unsigned i, Dim d) const { + assert(t < numTensors && i < numLoops); + return dims[t][i] == d; + } + + /// Returns true if any set bit corresponds to queried dim. + bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const; + + /// Setter + void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; } + + /// Getters. + TensorExp &exp(unsigned e) { return tensorExps[e]; } + LatPoint &lat(unsigned l) { return latPoints[l]; } + SmallVector &set(unsigned s) { return latSets[s]; } + +private: + const unsigned outTensor; + const unsigned numTensors; + const unsigned numLoops; + + std::vector> dims; + llvm::SmallVector tensorExps; + llvm::SmallVector latPoints; + llvm::SmallVector, 8> latSets; +}; + +} // namespace sparse_tensor +} // namespace mlir + +#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_ diff --git a/mlir/lib/Dialect/SparseTensor/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/CMakeLists.txt index 9f57627c321f..31167e6af908 100644 --- a/mlir/lib/Dialect/SparseTensor/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt index 68adb6fe1db1..24600aace642 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms MLIRSCF MLIRStandard MLIRSparseTensor + MLIRSparseTensorUtils MLIRTensor MLIRTransforms MLIRVector diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 9c406a36f072..f12aaccb3169 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -47,6 +47,7 @@ #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Dialect/SparseTensor/Utils/Merger.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Matchers.h" @@ -58,245 +59,6 @@ using namespace mlir::sparse_tensor; namespace { -enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI }; -enum class Dim { kSparse, kDense, kSingle, kUndef }; - -/// Tensor expression. Represents a MLIR expression in tensor index notation. -/// For tensors, e0 denotes the tensor index. For invariants, the IR value is -/// stored directly. For binary operations, e0 and e1 denote the index of the -/// children tensor expressions. -struct TensorExp { - TensorExp(Kind k, unsigned x, unsigned y, Value v) - : kind(k), e0(x), e1(y), val(v) { - assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) || - (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) || - (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val)); - } - Kind kind; - /// Indices of children expression(s). - unsigned e0; - unsigned e1; - /// Direct link to IR for an invariant. During code generation, - /// field is used to cache "hoisted" loop invariant tensor loads. - Value val; -}; - -/// Lattice point. Each lattice point consists of a conjunction of tensor -/// loop indices (encoded in a bitvector) and the index of the corresponding -/// tensor expression. -struct LatPoint { - LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) { - bits.set(b); - } - LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {} - /// Conjunction of tensor loop indices as bitvector. This represents - /// all indices involved in the tensor expression - llvm::BitVector bits; - /// Simplified conjunction of tensor loop indices as bitvector. This - /// represents a simplified condition under which this tensor expression - /// must execute. Pre-computed during codegen to avoid repeated eval. - llvm::BitVector simple; - /// Index of the tensor expresssion. - unsigned exp; -}; - -/// A class to handle all iteration lattice operations. This class abstracts -/// away from some implementation details of storing iteration lattices and -/// tensor expressions. This allows for fine-tuning performance characteristics -/// independently from the basic algorithm if bottlenecks are identified. -class Merger { -public: - /// Constructs a merger for the given number of tensors and loops. The - /// user supplies the number of tensors involved in the kernel, with the - /// last tensor in this set denoting the output tensor. The merger adds an - /// additional synthetic tensor at the end of this set to represent all - /// invariant expressions in the kernel. - Merger(unsigned t, unsigned l) - : outTensor(t - 1), numTensors(t + 1), numLoops(l), - dims(t + 1, std::vector(l, Dim::kUndef)) {} - - /// Adds a tensor expression. Returns its index. - unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) { - unsigned e = tensorExps.size(); - tensorExps.push_back(TensorExp(k, e0, e1, v)); - return e; - } - unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); } - - /// Adds an iteration lattice point. Returns its index. - unsigned addLat(unsigned t, unsigned i, unsigned e) { - assert(t < numTensors && i < numLoops); - unsigned p = latPoints.size(); - latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t)); - return p; - } - - /// Adds a new, initially empty, set. Returns its index. - unsigned addSet() { - unsigned s = latSets.size(); - latSets.emplace_back(SmallVector()); - return s; - } - - /// Computes a single conjunction of two lattice points by taking the "union" - /// of loop indices (effectively constructing a larger "intersection" of those - /// indices) with a newly constructed tensor (sub)expression of given kind. - /// Returns the index of the new lattice point. - unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) { - unsigned p = latPoints.size(); - llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits); - nb |= latPoints[p1].bits; - unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp); - latPoints.push_back(LatPoint(nb, e)); - return p; - } - - /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of - /// cartesian product. Returns the index of the new set. - unsigned takeConj(Kind kind, unsigned s0, unsigned s1) { - unsigned s = addSet(); - for (unsigned p0 : latSets[s0]) - for (unsigned p1 : latSets[s1]) - latSets[s].push_back(conjLatPoint(kind, p0, p1)); - return s; - } - - /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1). - /// Returns the index of the new set. - unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) { - unsigned s = takeConj(kind, s0, s1); - for (unsigned p : latSets[s0]) - latSets[s].push_back(p); - for (unsigned p : latSets[s1]) - latSets[s].push_back(p); - return s; - } - - /// Optimizes the iteration lattice points in the given set. This - /// method should be called right before code generation to avoid - /// generating redundant loops and conditions. - unsigned optimizeSet(unsigned s0) { - unsigned s = addSet(); - assert(latSets[s0].size() != 0); - unsigned p0 = latSets[s0][0]; - for (unsigned p1 : latSets[s0]) { - bool add = true; - if (p0 != p1) { - // Is this a straightforward copy? - unsigned e = latPoints[p1].exp; - if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor) - continue; - // Conjunction already covered? - for (unsigned p2 : latSets[s]) { - assert(!latGT(p1, p2)); // Lj => Li would be bad - if (onlyDenseDiff(p2, p1)) { - add = false; - break; - } - } - assert(!add || latGT(p0, p1)); - } - if (add) - latSets[s].push_back(p1); - } - for (unsigned p : latSets[s]) - latPoints[p].simple = simplifyCond(s, p); - return s; - } - - /// Simplifies the conditions in a conjunction of a given lattice point - /// within the given set using just two basic rules: - /// (1) multiple dense conditions are reduced to single dense, and - /// (2) a *singleton* sparse/dense is reduced to sparse/random access. - llvm::BitVector simplifyCond(unsigned s, unsigned p0) { - // First determine if this lattice point is a *singleton*, i.e., - // the last point in a lattice, no other is less than this one. - bool isSingleton = true; - for (unsigned p1 : latSets[s]) { - if (p0 != p1 && latGT(p0, p1)) { - isSingleton = false; - break; - } - } - // Now apply the two basic rules. - llvm::BitVector simple = latPoints[p0].bits; - bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse); - for (unsigned b = 0, be = simple.size(); b < be; b++) { - if (simple[b] && !isDim(b, Dim::kSparse)) { - if (reset) - simple.reset(b); - reset = true; - } - } - return simple; - } - - /// Returns true if Li > Lj. - bool latGT(unsigned i, unsigned j) const { - const llvm::BitVector &bitsi = latPoints[i].bits; - const llvm::BitVector &bitsj = latPoints[j].bits; - assert(bitsi.size() == bitsj.size()); - if (bitsi.count() > bitsj.count()) { - for (unsigned b = 0, be = bitsj.size(); b < be; b++) - if (bitsj[b] && !bitsi[b]) - return false; - return true; - } - return false; - } - - /// Returns true if Li and Lj only differ in dense. - bool onlyDenseDiff(unsigned i, unsigned j) { - llvm::BitVector tmp = latPoints[j].bits; - tmp ^= latPoints[i].bits; - return !hasAnyDimOf(tmp, Dim::kSparse); - } - - /// Bit translation. - unsigned tensor(unsigned b) const { return b % numTensors; } - unsigned index(unsigned b) const { return b / numTensors; } - - /// Returns true if bit corresponds to queried dim. - bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); } - - /// Returns true if bit corresponds to index of output tensor. - bool isOutTensor(unsigned b, unsigned i) const { - return tensor(b) == outTensor && index(b) == i; - } - - /// Returns true if tensor access at given index has queried dim. - bool isDim(unsigned t, unsigned i, Dim d) const { - assert(t < numTensors && i < numLoops); - return dims[t][i] == d; - } - - /// Returns true if any set bit corresponds to queried dim. - bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const { - for (unsigned b = 0, be = bits.size(); b < be; b++) - if (bits[b] && isDim(b, d)) - return true; - return false; - } - - /// Setter - void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; } - - /// Getters. - TensorExp &exp(unsigned e) { return tensorExps[e]; } - LatPoint &lat(unsigned l) { return latPoints[l]; } - SmallVector &set(unsigned s) { return latSets[s]; } - -private: - const unsigned outTensor; - const unsigned numTensors; - const unsigned numLoops; - - std::vector> dims; - llvm::SmallVector tensorExps; - llvm::SmallVector latPoints; - llvm::SmallVector, 8> latSets; -}; - // Code generation. struct CodeGen { CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops) diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt new file mode 100644 index 000000000000..bfd614cb8df4 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt @@ -0,0 +1,9 @@ +add_mlir_dialect_library(MLIRSparseTensorUtils + Merger.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor + + LINK_LIBS PUBLIC + MLIRIR +) diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp new file mode 100644 index 000000000000..0d1d34597afc --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -0,0 +1,138 @@ +//===- Merger.cpp - Implementation of iteration lattices ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SparseTensor/Utils/Merger.h" + +namespace mlir { +namespace sparse_tensor { + +unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) { + unsigned e = tensorExps.size(); + tensorExps.push_back(TensorExp(k, e0, e1, v)); + return e; +} + +unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) { + assert(t < numTensors && i < numLoops); + unsigned p = latPoints.size(); + latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t)); + return p; +} + +unsigned Merger::addSet() { + unsigned s = latSets.size(); + latSets.emplace_back(SmallVector()); + return s; +} + +unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) { + unsigned p = latPoints.size(); + llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits); + nb |= latPoints[p1].bits; + unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp); + latPoints.push_back(LatPoint(nb, e)); + return p; +} + +unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) { + unsigned s = addSet(); + for (unsigned p0 : latSets[s0]) + for (unsigned p1 : latSets[s1]) + latSets[s].push_back(conjLatPoint(kind, p0, p1)); + return s; +} + +unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) { + unsigned s = takeConj(kind, s0, s1); + for (unsigned p : latSets[s0]) + latSets[s].push_back(p); + for (unsigned p : latSets[s1]) + latSets[s].push_back(p); + return s; +} + +unsigned Merger::optimizeSet(unsigned s0) { + unsigned s = addSet(); + assert(latSets[s0].size() != 0); + unsigned p0 = latSets[s0][0]; + for (unsigned p1 : latSets[s0]) { + bool add = true; + if (p0 != p1) { + // Is this a straightforward copy? + unsigned e = latPoints[p1].exp; + if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor) + continue; + // Conjunction already covered? + for (unsigned p2 : latSets[s]) { + assert(!latGT(p1, p2)); // Lj => Li would be bad + if (onlyDenseDiff(p2, p1)) { + add = false; + break; + } + } + assert(!add || latGT(p0, p1)); + } + if (add) + latSets[s].push_back(p1); + } + for (unsigned p : latSets[s]) + latPoints[p].simple = simplifyCond(s, p); + return s; +} + +llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) { + // First determine if this lattice point is a *singleton*, i.e., + // the last point in a lattice, no other is less than this one. + bool isSingleton = true; + for (unsigned p1 : latSets[s]) { + if (p0 != p1 && latGT(p0, p1)) { + isSingleton = false; + break; + } + } + // Now apply the two basic rules. + llvm::BitVector simple = latPoints[p0].bits; + bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse); + for (unsigned b = 0, be = simple.size(); b < be; b++) { + if (simple[b] && !isDim(b, Dim::kSparse)) { + if (reset) + simple.reset(b); + reset = true; + } + } + return simple; +} + +bool Merger::latGT(unsigned i, unsigned j) const { + const llvm::BitVector &bitsi = latPoints[i].bits; + const llvm::BitVector &bitsj = latPoints[j].bits; + assert(bitsi.size() == bitsj.size()); + if (bitsi.count() > bitsj.count()) { + for (unsigned b = 0, be = bitsj.size(); b < be; b++) + if (bitsj[b] && !bitsi[b]) + return false; + return true; + } + return false; +} + +bool Merger::onlyDenseDiff(unsigned i, unsigned j) { + llvm::BitVector tmp = latPoints[j].bits; + tmp ^= latPoints[i].bits; + return !hasAnyDimOf(tmp, Dim::kSparse); +} + +bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const { + for (unsigned b = 0, be = bits.size(); b < be; b++) + if (bits[b] && isDim(b, d)) + return true; + return false; +} + +} // namespace sparse_tensor +} // namespace mlir