[MLIR][Sparse] Refactor lattice code into its own file

Moves the iteration lattice/merger code into a new SparseTensor/Utils directory. A follow-up CL will add lattice/merger unit tests.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D104757
Gus Smith 2021-06-24 22:18:40 +00:00
parent b2787945f9
commit 744146f60b
6 changed files with 313 additions and 239 deletions


@@ -0,0 +1,163 @@
//===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities for dealing with iteration lattices.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
#include "mlir/IR/Value.h"
#include "llvm/ADT/BitVector.h"
namespace mlir {
namespace sparse_tensor {
enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
enum class Dim { kSparse, kDense, kSingle, kUndef };
/// Tensor expression. Represents an MLIR expression in tensor index notation.
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
/// stored directly. For binary operations, e0 and e1 denote the index of the
/// children tensor expressions.
struct TensorExp {
TensorExp(Kind k, unsigned x, unsigned y, Value v)
: kind(k), e0(x), e1(y), val(v) {
assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
(kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
(kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
}
Kind kind;
/// Indices of children expression(s).
unsigned e0;
unsigned e1;
/// Direct link to IR for an invariant. During code generation,
/// field is used to cache "hoisted" loop invariant tensor loads.
Value val;
};
/// Lattice point. Each lattice point consists of a conjunction of tensor
/// loop indices (encoded in a bitvector) and the index of the corresponding
/// tensor expression.
struct LatPoint {
LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
bits.set(b);
}
LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
/// Conjunction of tensor loop indices as bitvector. This represents
/// all indices involved in the tensor expression.
llvm::BitVector bits;
/// Simplified conjunction of tensor loop indices as bitvector. This
/// represents a simplified condition under which this tensor expression
/// must execute. Pre-computed during codegen to avoid repeated eval.
llvm::BitVector simple;
/// Index of the tensor expression.
unsigned exp;
};
/// A class to handle all iteration lattice operations. This class abstracts
/// away from some implementation details of storing iteration lattices and
/// tensor expressions. This allows for fine-tuning performance characteristics
/// independently from the basic algorithm if bottlenecks are identified.
class Merger {
public:
/// Constructs a merger for the given number of tensors and loops. The
/// user supplies the number of tensors involved in the kernel, with the
/// last tensor in this set denoting the output tensor. The merger adds an
/// additional synthetic tensor at the end of this set to represent all
/// invariant expressions in the kernel.
Merger(unsigned t, unsigned l)
: outTensor(t - 1), numTensors(t + 1), numLoops(l),
dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
/// Adds a tensor expression. Returns its index.
unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
/// Adds an iteration lattice point. Returns its index.
unsigned addLat(unsigned t, unsigned i, unsigned e);
/// Adds a new, initially empty, set. Returns its index.
unsigned addSet();
/// Computes a single conjunction of two lattice points by taking the "union"
/// of loop indices (effectively constructing a larger "intersection" of those
/// indices) with a newly constructed tensor (sub)expression of given kind.
/// Returns the index of the new lattice point.
unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1);
/// Conjunctive merge of two lattice sets L0 and L1 is the conjunction of the
/// Cartesian product. Returns the index of the new set.
unsigned takeConj(Kind kind, unsigned s0, unsigned s1);
/// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
/// Returns the index of the new set.
unsigned takeDisj(Kind kind, unsigned s0, unsigned s1);
/// Optimizes the iteration lattice points in the given set. This
/// method should be called right before code generation to avoid
/// generating redundant loops and conditions.
unsigned optimizeSet(unsigned s0);
/// Simplifies the conditions in a conjunction of a given lattice point
/// within the given set using just two basic rules:
/// (1) multiple dense conditions are reduced to single dense, and
/// (2) a *singleton* sparse/dense is reduced to sparse/random access.
llvm::BitVector simplifyCond(unsigned s, unsigned p0);
/// Returns true if Li > Lj.
bool latGT(unsigned i, unsigned j) const;
/// Returns true if Li and Lj only differ in dense.
bool onlyDenseDiff(unsigned i, unsigned j);
/// Bit translation.
unsigned tensor(unsigned b) const { return b % numTensors; }
unsigned index(unsigned b) const { return b / numTensors; }
/// Returns true if bit corresponds to queried dim.
bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
/// Returns true if bit corresponds to index of output tensor.
bool isOutTensor(unsigned b, unsigned i) const {
return tensor(b) == outTensor && index(b) == i;
}
/// Returns true if tensor access at given index has queried dim.
bool isDim(unsigned t, unsigned i, Dim d) const {
assert(t < numTensors && i < numLoops);
return dims[t][i] == d;
}
/// Returns true if any set bit corresponds to queried dim.
bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;
/// Setter
void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
/// Getters.
TensorExp &exp(unsigned e) { return tensorExps[e]; }
LatPoint &lat(unsigned l) { return latPoints[l]; }
SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
private:
const unsigned outTensor;
const unsigned numTensors;
const unsigned numLoops;
std::vector<std::vector<Dim>> dims;
llvm::SmallVector<TensorExp, 32> tensorExps;
llvm::SmallVector<LatPoint, 16> latPoints;
llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
};
} // namespace sparse_tensor
} // namespace mlir
#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
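
// Illustrative sketch only (not part of this commit): one way a client could
// drive the Merger API declared above for a one-loop kernel x(i) = a(i) + b(i).
// The tensor ids, dimension kinds, and the helper name buildLattices are
// assumptions made purely for this example.
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

using namespace mlir::sparse_tensor;

static void buildLattices() {
  // Three tensors (a, b, and the output x) and a single loop index i.
  Merger merger(/*t=*/3, /*l=*/1);
  merger.setDim(/*t=*/0, /*i=*/0, Dim::kSparse); // a(i) stored sparsely
  merger.setDim(/*t=*/1, /*i=*/0, Dim::kDense);  // b(i) stored densely
  merger.setDim(/*t=*/2, /*i=*/0, Dim::kDense);  // x(i) dense output
  // Leaf expressions: for kTensor nodes, e0 holds the tensor id.
  unsigned ea = merger.addExp(Kind::kTensor, /*e0=*/0);
  unsigned eb = merger.addExp(Kind::kTensor, /*e0=*/1);
  // Singleton lattice sets {a(i)} and {b(i)}; in each lattice point, bit
  // numTensors * i + t records "tensor t is indexed by loop i".
  unsigned la = merger.addLat(/*t=*/0, /*i=*/0, ea);
  unsigned lb = merger.addLat(/*t=*/1, /*i=*/0, eb);
  unsigned sa = merger.addSet();
  merger.set(sa).push_back(la);
  unsigned sb = merger.addSet();
  merger.set(sb).push_back(lb);
  // Addition is a disjunction: { a(i)+b(i), a(i), b(i) }.
  unsigned s = merger.takeDisj(Kind::kAddF, sa, sb);
  // Prune redundant points and precompute simplified loop conditions.
  unsigned opt = merger.optimizeSet(s);
  (void)opt; // points in merger.set(opt) would drive loop generation
}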


@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(Utils)


@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
  MLIRSCF
  MLIRStandard
  MLIRSparseTensor
  MLIRSparseTensorUtils
  MLIRTensor
  MLIRTransforms
  MLIRVector


@@ -47,6 +47,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
@@ -58,245 +59,6 @@ using namespace mlir::sparse_tensor;
namespace {
enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
enum class Dim { kSparse, kDense, kSingle, kUndef };
/// Tensor expression. Represents an MLIR expression in tensor index notation.
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
/// stored directly. For binary operations, e0 and e1 denote the index of the
/// children tensor expressions.
struct TensorExp {
TensorExp(Kind k, unsigned x, unsigned y, Value v)
: kind(k), e0(x), e1(y), val(v) {
assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
(kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
(kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
}
Kind kind;
/// Indices of children expression(s).
unsigned e0;
unsigned e1;
/// Direct link to IR for an invariant. During code generation,
/// field is used to cache "hoisted" loop invariant tensor loads.
Value val;
};
/// Lattice point. Each lattice point consists of a conjunction of tensor
/// loop indices (encoded in a bitvector) and the index of the corresponding
/// tensor expression.
struct LatPoint {
LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
bits.set(b);
}
LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
/// Conjunction of tensor loop indices as bitvector. This represents
/// all indices involved in the tensor expression.
llvm::BitVector bits;
/// Simplified conjunction of tensor loop indices as bitvector. This
/// represents a simplified condition under which this tensor expression
/// must execute. Pre-computed during codegen to avoid repeated eval.
llvm::BitVector simple;
/// Index of the tensor expression.
unsigned exp;
};
/// A class to handle all iteration lattice operations. This class abstracts
/// away from some implementation details of storing iteration lattices and
/// tensor expressions. This allows for fine-tuning performance characteristics
/// independently from the basic algorithm if bottlenecks are identified.
class Merger {
public:
/// Constructs a merger for the given number of tensors and loops. The
/// user supplies the number of tensors involved in the kernel, with the
/// last tensor in this set denoting the output tensor. The merger adds an
/// additional synthetic tensor at the end of this set to represent all
/// invariant expressions in the kernel.
Merger(unsigned t, unsigned l)
: outTensor(t - 1), numTensors(t + 1), numLoops(l),
dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
/// Adds a tensor expression. Returns its index.
unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
unsigned e = tensorExps.size();
tensorExps.push_back(TensorExp(k, e0, e1, v));
return e;
}
unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
/// Adds an iteration lattice point. Returns its index.
unsigned addLat(unsigned t, unsigned i, unsigned e) {
assert(t < numTensors && i < numLoops);
unsigned p = latPoints.size();
latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
return p;
}
/// Adds a new, initially empty, set. Returns its index.
unsigned addSet() {
unsigned s = latSets.size();
latSets.emplace_back(SmallVector<unsigned, 16>());
return s;
}
/// Computes a single conjunction of two lattice points by taking the "union"
/// of loop indices (effectively constructing a larger "intersection" of those
/// indices) with a newly constructed tensor (sub)expression of given kind.
/// Returns the index of the new lattice point.
unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
unsigned p = latPoints.size();
llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
nb |= latPoints[p1].bits;
unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
latPoints.push_back(LatPoint(nb, e));
return p;
}
/// Conjunctive merge of two lattice sets L0 and L1 is the conjunction of the
/// Cartesian product. Returns the index of the new set.
unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
unsigned s = addSet();
for (unsigned p0 : latSets[s0])
for (unsigned p1 : latSets[s1])
latSets[s].push_back(conjLatPoint(kind, p0, p1));
return s;
}
/// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
/// Returns the index of the new set.
unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
unsigned s = takeConj(kind, s0, s1);
for (unsigned p : latSets[s0])
latSets[s].push_back(p);
for (unsigned p : latSets[s1])
latSets[s].push_back(p);
return s;
}
/// Optimizes the iteration lattice points in the given set. This
/// method should be called right before code generation to avoid
/// generating redundant loops and conditions.
unsigned optimizeSet(unsigned s0) {
unsigned s = addSet();
assert(latSets[s0].size() != 0);
unsigned p0 = latSets[s0][0];
for (unsigned p1 : latSets[s0]) {
bool add = true;
if (p0 != p1) {
// Is this a straightforward copy?
unsigned e = latPoints[p1].exp;
if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
continue;
// Conjunction already covered?
for (unsigned p2 : latSets[s]) {
assert(!latGT(p1, p2)); // Lj => Li would be bad
if (onlyDenseDiff(p2, p1)) {
add = false;
break;
}
}
assert(!add || latGT(p0, p1));
}
if (add)
latSets[s].push_back(p1);
}
for (unsigned p : latSets[s])
latPoints[p].simple = simplifyCond(s, p);
return s;
}
/// Simplifies the conditions in a conjunction of a given lattice point
/// within the given set using just two basic rules:
/// (1) multiple dense conditions are reduced to single dense, and
/// (2) a *singleton* sparse/dense is reduced to sparse/random access.
llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
// First determine if this lattice point is a *singleton*, i.e.,
// the last point in a lattice, no other is less than this one.
bool isSingleton = true;
for (unsigned p1 : latSets[s]) {
if (p0 != p1 && latGT(p0, p1)) {
isSingleton = false;
break;
}
}
// Now apply the two basic rules.
llvm::BitVector simple = latPoints[p0].bits;
bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
for (unsigned b = 0, be = simple.size(); b < be; b++) {
if (simple[b] && !isDim(b, Dim::kSparse)) {
if (reset)
simple.reset(b);
reset = true;
}
}
return simple;
}
/// Returns true if Li > Lj.
bool latGT(unsigned i, unsigned j) const {
const llvm::BitVector &bitsi = latPoints[i].bits;
const llvm::BitVector &bitsj = latPoints[j].bits;
assert(bitsi.size() == bitsj.size());
if (bitsi.count() > bitsj.count()) {
for (unsigned b = 0, be = bitsj.size(); b < be; b++)
if (bitsj[b] && !bitsi[b])
return false;
return true;
}
return false;
}
/// Returns true if Li and Lj only differ in dense.
bool onlyDenseDiff(unsigned i, unsigned j) {
llvm::BitVector tmp = latPoints[j].bits;
tmp ^= latPoints[i].bits;
return !hasAnyDimOf(tmp, Dim::kSparse);
}
/// Bit translation.
unsigned tensor(unsigned b) const { return b % numTensors; }
unsigned index(unsigned b) const { return b / numTensors; }
/// Returns true if bit corresponds to queried dim.
bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
/// Returns true if bit corresponds to index of output tensor.
bool isOutTensor(unsigned b, unsigned i) const {
return tensor(b) == outTensor && index(b) == i;
}
/// Returns true if tensor access at given index has queried dim.
bool isDim(unsigned t, unsigned i, Dim d) const {
assert(t < numTensors && i < numLoops);
return dims[t][i] == d;
}
/// Returns true if any set bit corresponds to queried dim.
bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
for (unsigned b = 0, be = bits.size(); b < be; b++)
if (bits[b] && isDim(b, d))
return true;
return false;
}
/// Setter
void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
/// Getters.
TensorExp &exp(unsigned e) { return tensorExps[e]; }
LatPoint &lat(unsigned l) { return latPoints[l]; }
SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
private:
const unsigned outTensor;
const unsigned numTensors;
const unsigned numLoops;
std::vector<std::vector<Dim>> dims;
llvm::SmallVector<TensorExp, 32> tensorExps;
llvm::SmallVector<LatPoint, 16> latPoints;
llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
};
// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)


@@ -0,0 +1,9 @@
add_mlir_dialect_library(MLIRSparseTensorUtils
  Merger.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor

  LINK_LIBS PUBLIC
  MLIRIR
  )


@@ -0,0 +1,138 @@
//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
namespace mlir {
namespace sparse_tensor {
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
unsigned e = tensorExps.size();
tensorExps.push_back(TensorExp(k, e0, e1, v));
return e;
}
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
assert(t < numTensors && i < numLoops);
unsigned p = latPoints.size();
latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
return p;
}
unsigned Merger::addSet() {
unsigned s = latSets.size();
latSets.emplace_back(SmallVector<unsigned, 16>());
return s;
}
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
unsigned p = latPoints.size();
llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
nb |= latPoints[p1].bits;
unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
latPoints.push_back(LatPoint(nb, e));
return p;
}
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
unsigned s = addSet();
for (unsigned p0 : latSets[s0])
for (unsigned p1 : latSets[s1])
latSets[s].push_back(conjLatPoint(kind, p0, p1));
return s;
}
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
unsigned s = takeConj(kind, s0, s1);
for (unsigned p : latSets[s0])
latSets[s].push_back(p);
for (unsigned p : latSets[s1])
latSets[s].push_back(p);
return s;
}
unsigned Merger::optimizeSet(unsigned s0) {
unsigned s = addSet();
assert(latSets[s0].size() != 0);
unsigned p0 = latSets[s0][0];
for (unsigned p1 : latSets[s0]) {
bool add = true;
if (p0 != p1) {
// Is this a straightforward copy?
unsigned e = latPoints[p1].exp;
if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
continue;
// Conjunction already covered?
for (unsigned p2 : latSets[s]) {
assert(!latGT(p1, p2)); // Lj => Li would be bad
if (onlyDenseDiff(p2, p1)) {
add = false;
break;
}
}
assert(!add || latGT(p0, p1));
}
if (add)
latSets[s].push_back(p1);
}
for (unsigned p : latSets[s])
latPoints[p].simple = simplifyCond(s, p);
return s;
}
llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) {
// First determine if this lattice point is a *singleton*, i.e.,
// the last point in a lattice, no other is less than this one.
bool isSingleton = true;
for (unsigned p1 : latSets[s]) {
if (p0 != p1 && latGT(p0, p1)) {
isSingleton = false;
break;
}
}
// Now apply the two basic rules.
llvm::BitVector simple = latPoints[p0].bits;
bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
for (unsigned b = 0, be = simple.size(); b < be; b++) {
if (simple[b] && !isDim(b, Dim::kSparse)) {
if (reset)
simple.reset(b);
reset = true;
}
}
return simple;
}
bool Merger::latGT(unsigned i, unsigned j) const {
const llvm::BitVector &bitsi = latPoints[i].bits;
const llvm::BitVector &bitsj = latPoints[j].bits;
assert(bitsi.size() == bitsj.size());
if (bitsi.count() > bitsj.count()) {
for (unsigned b = 0, be = bitsj.size(); b < be; b++)
if (bitsj[b] && !bitsi[b])
return false;
return true;
}
return false;
}
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
llvm::BitVector tmp = latPoints[j].bits;
tmp ^= latPoints[i].bits;
return !hasAnyDimOf(tmp, Dim::kSparse);
}
bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
for (unsigned b = 0, be = bits.size(); b < be; b++)
if (bits[b] && isDim(b, d))
return true;
return false;
}
} // namespace sparse_tensor
} // namespace mlir
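
The commit message mentions a follow-up CL that will add lattice/merger unit tests. Below is a minimal sketch of what such a test could look like, assuming a googletest-based MLIR unit test; the test name, fixture layout, and expectations are illustrative and are not the actual follow-up change.
// Hypothetical unit test sketch exercising takeDisj and latGT on a
// two-input, one-loop kernel; all names here are assumptions.
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "gtest/gtest.h"

using namespace mlir::sparse_tensor;

TEST(MergerTest, DisjunctionYieldsThreeLatticePoints) {
  // Tensors a, b, and output x; one loop index i.
  Merger merger(/*t=*/3, /*l=*/1);
  merger.setDim(0, 0, Dim::kSparse);
  merger.setDim(1, 0, Dim::kSparse);
  merger.setDim(2, 0, Dim::kDense);
  unsigned ea = merger.addExp(Kind::kTensor, /*e0=*/0);
  unsigned eb = merger.addExp(Kind::kTensor, /*e0=*/1);
  unsigned la = merger.addLat(/*t=*/0, /*i=*/0, ea);
  unsigned lb = merger.addLat(/*t=*/1, /*i=*/0, eb);
  unsigned sa = merger.addSet();
  merger.set(sa).push_back(la);
  unsigned sb = merger.addSet();
  merger.set(sb).push_back(lb);
  // Disjunction of {a(i)} and {b(i)}: conjunction first, then both singletons.
  unsigned s = merger.takeDisj(Kind::kAddF, sa, sb);
  EXPECT_EQ(merger.set(s).size(), 3u);
  // The combined point iterates both a(i) and b(i), so it is strictly greater
  // than either singleton point.
  EXPECT_TRUE(merger.latGT(merger.set(s)[0], merger.set(s)[1]));
  EXPECT_TRUE(merger.latGT(merger.set(s)[0], merger.set(s)[2]));
}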