forked from OSchip/llvm-project
[MLIR][Sparse] Refactor lattice code into its own file
Moves iteration lattice/merger code into new SparseTensor/Utils directory. A follow-up CL will add lattice/merger unit tests. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D104757
This commit is contained in:
parent
b2787945f9
commit
744146f60b
|
@ -0,0 +1,163 @@
|
||||||
|
//===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This header file defines utilities for dealing with iteration lattices.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
|
||||||
|
#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "llvm/ADT/BitVector.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace sparse_tensor {
|
||||||
|
|
||||||
|
enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
|
||||||
|
enum class Dim { kSparse, kDense, kSingle, kUndef };
|
||||||
|
|
||||||
|
/// Tensor expression. Represents a MLIR expression in tensor index notation.
|
||||||
|
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
|
||||||
|
/// stored directly. For binary operations, e0 and e1 denote the index of the
|
||||||
|
/// children tensor expressions.
|
||||||
|
struct TensorExp {
|
||||||
|
TensorExp(Kind k, unsigned x, unsigned y, Value v)
|
||||||
|
: kind(k), e0(x), e1(y), val(v) {
|
||||||
|
assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
|
||||||
|
(kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
|
||||||
|
(kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
|
||||||
|
}
|
||||||
|
Kind kind;
|
||||||
|
/// Indices of children expression(s).
|
||||||
|
unsigned e0;
|
||||||
|
unsigned e1;
|
||||||
|
/// Direct link to IR for an invariant. During code generation,
|
||||||
|
/// field is used to cache "hoisted" loop invariant tensor loads.
|
||||||
|
Value val;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Lattice point. Each lattice point consists of a conjunction of tensor
|
||||||
|
/// loop indices (encoded in a bitvector) and the index of the corresponding
|
||||||
|
/// tensor expression.
|
||||||
|
struct LatPoint {
|
||||||
|
LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
|
||||||
|
bits.set(b);
|
||||||
|
}
|
||||||
|
LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
|
||||||
|
/// Conjunction of tensor loop indices as bitvector. This represents
|
||||||
|
/// all indices involved in the tensor expression
|
||||||
|
llvm::BitVector bits;
|
||||||
|
/// Simplified conjunction of tensor loop indices as bitvector. This
|
||||||
|
/// represents a simplified condition under which this tensor expression
|
||||||
|
/// must execute. Pre-computed during codegen to avoid repeated eval.
|
||||||
|
llvm::BitVector simple;
|
||||||
|
/// Index of the tensor expresssion.
|
||||||
|
unsigned exp;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A class to handle all iteration lattice operations. This class abstracts
|
||||||
|
/// away from some implementation details of storing iteration lattices and
|
||||||
|
/// tensor expressions. This allows for fine-tuning performance characteristics
|
||||||
|
/// independently from the basic algorithm if bottlenecks are identified.
|
||||||
|
class Merger {
|
||||||
|
public:
|
||||||
|
/// Constructs a merger for the given number of tensors and loops. The
|
||||||
|
/// user supplies the number of tensors involved in the kernel, with the
|
||||||
|
/// last tensor in this set denoting the output tensor. The merger adds an
|
||||||
|
/// additional synthetic tensor at the end of this set to represent all
|
||||||
|
/// invariant expressions in the kernel.
|
||||||
|
Merger(unsigned t, unsigned l)
|
||||||
|
: outTensor(t - 1), numTensors(t + 1), numLoops(l),
|
||||||
|
dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
|
||||||
|
|
||||||
|
/// Adds a tensor expression. Returns its index.
|
||||||
|
unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
|
||||||
|
unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
|
||||||
|
|
||||||
|
/// Adds an iteration lattice point. Returns its index.
|
||||||
|
unsigned addLat(unsigned t, unsigned i, unsigned e);
|
||||||
|
|
||||||
|
/// Adds a new, initially empty, set. Returns its index.
|
||||||
|
unsigned addSet();
|
||||||
|
|
||||||
|
/// Computes a single conjunction of two lattice points by taking the "union"
|
||||||
|
/// of loop indices (effectively constructing a larger "intersection" of those
|
||||||
|
/// indices) with a newly constructed tensor (sub)expression of given kind.
|
||||||
|
/// Returns the index of the new lattice point.
|
||||||
|
unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1);
|
||||||
|
|
||||||
|
/// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
|
||||||
|
/// cartesian product. Returns the index of the new set.
|
||||||
|
unsigned takeConj(Kind kind, unsigned s0, unsigned s1);
|
||||||
|
|
||||||
|
/// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
|
||||||
|
/// Returns the index of the new set.
|
||||||
|
unsigned takeDisj(Kind kind, unsigned s0, unsigned s1);
|
||||||
|
|
||||||
|
/// Optimizes the iteration lattice points in the given set. This
|
||||||
|
/// method should be called right before code generation to avoid
|
||||||
|
/// generating redundant loops and conditions.
|
||||||
|
unsigned optimizeSet(unsigned s0);
|
||||||
|
|
||||||
|
/// Simplifies the conditions in a conjunction of a given lattice point
|
||||||
|
/// within the given set using just two basic rules:
|
||||||
|
/// (1) multiple dense conditions are reduced to single dense, and
|
||||||
|
/// (2) a *singleton* sparse/dense is reduced to sparse/random access.
|
||||||
|
llvm::BitVector simplifyCond(unsigned s, unsigned p0);
|
||||||
|
|
||||||
|
/// Returns true if Li > Lj.
|
||||||
|
bool latGT(unsigned i, unsigned j) const;
|
||||||
|
|
||||||
|
/// Returns true if Li and Lj only differ in dense.
|
||||||
|
bool onlyDenseDiff(unsigned i, unsigned j);
|
||||||
|
|
||||||
|
/// Bit translation.
|
||||||
|
unsigned tensor(unsigned b) const { return b % numTensors; }
|
||||||
|
unsigned index(unsigned b) const { return b / numTensors; }
|
||||||
|
|
||||||
|
/// Returns true if bit corresponds to queried dim.
|
||||||
|
bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
|
||||||
|
|
||||||
|
/// Returns true if bit corresponds to index of output tensor.
|
||||||
|
bool isOutTensor(unsigned b, unsigned i) const {
|
||||||
|
return tensor(b) == outTensor && index(b) == i;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if tensor access at given index has queried dim.
|
||||||
|
bool isDim(unsigned t, unsigned i, Dim d) const {
|
||||||
|
assert(t < numTensors && i < numLoops);
|
||||||
|
return dims[t][i] == d;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if any set bit corresponds to queried dim.
|
||||||
|
bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;
|
||||||
|
|
||||||
|
/// Setter
|
||||||
|
void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
|
||||||
|
|
||||||
|
/// Getters.
|
||||||
|
TensorExp &exp(unsigned e) { return tensorExps[e]; }
|
||||||
|
LatPoint &lat(unsigned l) { return latPoints[l]; }
|
||||||
|
SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
const unsigned outTensor;
|
||||||
|
const unsigned numTensors;
|
||||||
|
const unsigned numLoops;
|
||||||
|
|
||||||
|
std::vector<std::vector<Dim>> dims;
|
||||||
|
llvm::SmallVector<TensorExp, 32> tensorExps;
|
||||||
|
llvm::SmallVector<LatPoint, 16> latPoints;
|
||||||
|
llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sparse_tensor
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
|
|
@ -1,2 +1,3 @@
|
||||||
add_subdirectory(IR)
|
add_subdirectory(IR)
|
||||||
add_subdirectory(Transforms)
|
add_subdirectory(Transforms)
|
||||||
|
add_subdirectory(Utils)
|
||||||
|
|
|
@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
|
||||||
MLIRSCF
|
MLIRSCF
|
||||||
MLIRStandard
|
MLIRStandard
|
||||||
MLIRSparseTensor
|
MLIRSparseTensor
|
||||||
|
MLIRSparseTensorUtils
|
||||||
MLIRTensor
|
MLIRTensor
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
MLIRVector
|
MLIRVector
|
||||||
|
|
|
@ -47,6 +47,7 @@
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
||||||
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
||||||
|
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
|
@ -58,245 +59,6 @@ using namespace mlir::sparse_tensor;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
|
|
||||||
enum class Dim { kSparse, kDense, kSingle, kUndef };
|
|
||||||
|
|
||||||
/// Tensor expression. Represents a MLIR expression in tensor index notation.
|
|
||||||
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
|
|
||||||
/// stored directly. For binary operations, e0 and e1 denote the index of the
|
|
||||||
/// children tensor expressions.
|
|
||||||
struct TensorExp {
|
|
||||||
TensorExp(Kind k, unsigned x, unsigned y, Value v)
|
|
||||||
: kind(k), e0(x), e1(y), val(v) {
|
|
||||||
assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
|
|
||||||
(kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
|
|
||||||
(kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
|
|
||||||
}
|
|
||||||
Kind kind;
|
|
||||||
/// Indices of children expression(s).
|
|
||||||
unsigned e0;
|
|
||||||
unsigned e1;
|
|
||||||
/// Direct link to IR for an invariant. During code generation,
|
|
||||||
/// field is used to cache "hoisted" loop invariant tensor loads.
|
|
||||||
Value val;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Lattice point. Each lattice point consists of a conjunction of tensor
|
|
||||||
/// loop indices (encoded in a bitvector) and the index of the corresponding
|
|
||||||
/// tensor expression.
|
|
||||||
struct LatPoint {
|
|
||||||
LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
|
|
||||||
bits.set(b);
|
|
||||||
}
|
|
||||||
LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
|
|
||||||
/// Conjunction of tensor loop indices as bitvector. This represents
|
|
||||||
/// all indices involved in the tensor expression
|
|
||||||
llvm::BitVector bits;
|
|
||||||
/// Simplified conjunction of tensor loop indices as bitvector. This
|
|
||||||
/// represents a simplified condition under which this tensor expression
|
|
||||||
/// must execute. Pre-computed during codegen to avoid repeated eval.
|
|
||||||
llvm::BitVector simple;
|
|
||||||
/// Index of the tensor expresssion.
|
|
||||||
unsigned exp;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// A class to handle all iteration lattice operations. This class abstracts
|
|
||||||
/// away from some implementation details of storing iteration lattices and
|
|
||||||
/// tensor expressions. This allows for fine-tuning performance characteristics
|
|
||||||
/// independently from the basic algorithm if bottlenecks are identified.
|
|
||||||
class Merger {
|
|
||||||
public:
|
|
||||||
/// Constructs a merger for the given number of tensors and loops. The
|
|
||||||
/// user supplies the number of tensors involved in the kernel, with the
|
|
||||||
/// last tensor in this set denoting the output tensor. The merger adds an
|
|
||||||
/// additional synthetic tensor at the end of this set to represent all
|
|
||||||
/// invariant expressions in the kernel.
|
|
||||||
Merger(unsigned t, unsigned l)
|
|
||||||
: outTensor(t - 1), numTensors(t + 1), numLoops(l),
|
|
||||||
dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
|
|
||||||
|
|
||||||
/// Adds a tensor expression. Returns its index.
|
|
||||||
unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
|
|
||||||
unsigned e = tensorExps.size();
|
|
||||||
tensorExps.push_back(TensorExp(k, e0, e1, v));
|
|
||||||
return e;
|
|
||||||
}
|
|
||||||
unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
|
|
||||||
|
|
||||||
/// Adds an iteration lattice point. Returns its index.
|
|
||||||
unsigned addLat(unsigned t, unsigned i, unsigned e) {
|
|
||||||
assert(t < numTensors && i < numLoops);
|
|
||||||
unsigned p = latPoints.size();
|
|
||||||
latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
|
|
||||||
return p;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Adds a new, initially empty, set. Returns its index.
|
|
||||||
unsigned addSet() {
|
|
||||||
unsigned s = latSets.size();
|
|
||||||
latSets.emplace_back(SmallVector<unsigned, 16>());
|
|
||||||
return s;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Computes a single conjunction of two lattice points by taking the "union"
|
|
||||||
/// of loop indices (effectively constructing a larger "intersection" of those
|
|
||||||
/// indices) with a newly constructed tensor (sub)expression of given kind.
|
|
||||||
/// Returns the index of the new lattice point.
|
|
||||||
unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
|
|
||||||
unsigned p = latPoints.size();
|
|
||||||
llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
|
|
||||||
nb |= latPoints[p1].bits;
|
|
||||||
unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
|
|
||||||
latPoints.push_back(LatPoint(nb, e));
|
|
||||||
return p;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
|
|
||||||
/// cartesian product. Returns the index of the new set.
|
|
||||||
unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
|
|
||||||
unsigned s = addSet();
|
|
||||||
for (unsigned p0 : latSets[s0])
|
|
||||||
for (unsigned p1 : latSets[s1])
|
|
||||||
latSets[s].push_back(conjLatPoint(kind, p0, p1));
|
|
||||||
return s;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
|
|
||||||
/// Returns the index of the new set.
|
|
||||||
unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
|
|
||||||
unsigned s = takeConj(kind, s0, s1);
|
|
||||||
for (unsigned p : latSets[s0])
|
|
||||||
latSets[s].push_back(p);
|
|
||||||
for (unsigned p : latSets[s1])
|
|
||||||
latSets[s].push_back(p);
|
|
||||||
return s;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Optimizes the iteration lattice points in the given set. This
|
|
||||||
/// method should be called right before code generation to avoid
|
|
||||||
/// generating redundant loops and conditions.
|
|
||||||
unsigned optimizeSet(unsigned s0) {
|
|
||||||
unsigned s = addSet();
|
|
||||||
assert(latSets[s0].size() != 0);
|
|
||||||
unsigned p0 = latSets[s0][0];
|
|
||||||
for (unsigned p1 : latSets[s0]) {
|
|
||||||
bool add = true;
|
|
||||||
if (p0 != p1) {
|
|
||||||
// Is this a straightforward copy?
|
|
||||||
unsigned e = latPoints[p1].exp;
|
|
||||||
if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
|
|
||||||
continue;
|
|
||||||
// Conjunction already covered?
|
|
||||||
for (unsigned p2 : latSets[s]) {
|
|
||||||
assert(!latGT(p1, p2)); // Lj => Li would be bad
|
|
||||||
if (onlyDenseDiff(p2, p1)) {
|
|
||||||
add = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert(!add || latGT(p0, p1));
|
|
||||||
}
|
|
||||||
if (add)
|
|
||||||
latSets[s].push_back(p1);
|
|
||||||
}
|
|
||||||
for (unsigned p : latSets[s])
|
|
||||||
latPoints[p].simple = simplifyCond(s, p);
|
|
||||||
return s;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Simplifies the conditions in a conjunction of a given lattice point
|
|
||||||
/// within the given set using just two basic rules:
|
|
||||||
/// (1) multiple dense conditions are reduced to single dense, and
|
|
||||||
/// (2) a *singleton* sparse/dense is reduced to sparse/random access.
|
|
||||||
llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
|
|
||||||
// First determine if this lattice point is a *singleton*, i.e.,
|
|
||||||
// the last point in a lattice, no other is less than this one.
|
|
||||||
bool isSingleton = true;
|
|
||||||
for (unsigned p1 : latSets[s]) {
|
|
||||||
if (p0 != p1 && latGT(p0, p1)) {
|
|
||||||
isSingleton = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Now apply the two basic rules.
|
|
||||||
llvm::BitVector simple = latPoints[p0].bits;
|
|
||||||
bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
|
|
||||||
for (unsigned b = 0, be = simple.size(); b < be; b++) {
|
|
||||||
if (simple[b] && !isDim(b, Dim::kSparse)) {
|
|
||||||
if (reset)
|
|
||||||
simple.reset(b);
|
|
||||||
reset = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return simple;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true if Li > Lj.
|
|
||||||
bool latGT(unsigned i, unsigned j) const {
|
|
||||||
const llvm::BitVector &bitsi = latPoints[i].bits;
|
|
||||||
const llvm::BitVector &bitsj = latPoints[j].bits;
|
|
||||||
assert(bitsi.size() == bitsj.size());
|
|
||||||
if (bitsi.count() > bitsj.count()) {
|
|
||||||
for (unsigned b = 0, be = bitsj.size(); b < be; b++)
|
|
||||||
if (bitsj[b] && !bitsi[b])
|
|
||||||
return false;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true if Li and Lj only differ in dense.
|
|
||||||
bool onlyDenseDiff(unsigned i, unsigned j) {
|
|
||||||
llvm::BitVector tmp = latPoints[j].bits;
|
|
||||||
tmp ^= latPoints[i].bits;
|
|
||||||
return !hasAnyDimOf(tmp, Dim::kSparse);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Bit translation.
|
|
||||||
unsigned tensor(unsigned b) const { return b % numTensors; }
|
|
||||||
unsigned index(unsigned b) const { return b / numTensors; }
|
|
||||||
|
|
||||||
/// Returns true if bit corresponds to queried dim.
|
|
||||||
bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
|
|
||||||
|
|
||||||
/// Returns true if bit corresponds to index of output tensor.
|
|
||||||
bool isOutTensor(unsigned b, unsigned i) const {
|
|
||||||
return tensor(b) == outTensor && index(b) == i;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true if tensor access at given index has queried dim.
|
|
||||||
bool isDim(unsigned t, unsigned i, Dim d) const {
|
|
||||||
assert(t < numTensors && i < numLoops);
|
|
||||||
return dims[t][i] == d;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true if any set bit corresponds to queried dim.
|
|
||||||
bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
|
|
||||||
for (unsigned b = 0, be = bits.size(); b < be; b++)
|
|
||||||
if (bits[b] && isDim(b, d))
|
|
||||||
return true;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Setter
|
|
||||||
void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
|
|
||||||
|
|
||||||
/// Getters.
|
|
||||||
TensorExp &exp(unsigned e) { return tensorExps[e]; }
|
|
||||||
LatPoint &lat(unsigned l) { return latPoints[l]; }
|
|
||||||
SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
const unsigned outTensor;
|
|
||||||
const unsigned numTensors;
|
|
||||||
const unsigned numLoops;
|
|
||||||
|
|
||||||
std::vector<std::vector<Dim>> dims;
|
|
||||||
llvm::SmallVector<TensorExp, 32> tensorExps;
|
|
||||||
llvm::SmallVector<LatPoint, 16> latPoints;
|
|
||||||
llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Code generation.
|
// Code generation.
|
||||||
struct CodeGen {
|
struct CodeGen {
|
||||||
CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
|
CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
add_mlir_dialect_library(MLIRSparseTensorUtils
|
||||||
|
Merger.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
)
|
|
@ -0,0 +1,138 @@
|
||||||
|
//===- Merger.cpp - Implementation of iteration lattices ------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace sparse_tensor {
|
||||||
|
|
||||||
|
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
|
||||||
|
unsigned e = tensorExps.size();
|
||||||
|
tensorExps.push_back(TensorExp(k, e0, e1, v));
|
||||||
|
return e;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
|
||||||
|
assert(t < numTensors && i < numLoops);
|
||||||
|
unsigned p = latPoints.size();
|
||||||
|
latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned Merger::addSet() {
|
||||||
|
unsigned s = latSets.size();
|
||||||
|
latSets.emplace_back(SmallVector<unsigned, 16>());
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
|
||||||
|
unsigned p = latPoints.size();
|
||||||
|
llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
|
||||||
|
nb |= latPoints[p1].bits;
|
||||||
|
unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
|
||||||
|
latPoints.push_back(LatPoint(nb, e));
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
|
||||||
|
unsigned s = addSet();
|
||||||
|
for (unsigned p0 : latSets[s0])
|
||||||
|
for (unsigned p1 : latSets[s1])
|
||||||
|
latSets[s].push_back(conjLatPoint(kind, p0, p1));
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
|
||||||
|
unsigned s = takeConj(kind, s0, s1);
|
||||||
|
for (unsigned p : latSets[s0])
|
||||||
|
latSets[s].push_back(p);
|
||||||
|
for (unsigned p : latSets[s1])
|
||||||
|
latSets[s].push_back(p);
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned Merger::optimizeSet(unsigned s0) {
|
||||||
|
unsigned s = addSet();
|
||||||
|
assert(latSets[s0].size() != 0);
|
||||||
|
unsigned p0 = latSets[s0][0];
|
||||||
|
for (unsigned p1 : latSets[s0]) {
|
||||||
|
bool add = true;
|
||||||
|
if (p0 != p1) {
|
||||||
|
// Is this a straightforward copy?
|
||||||
|
unsigned e = latPoints[p1].exp;
|
||||||
|
if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
|
||||||
|
continue;
|
||||||
|
// Conjunction already covered?
|
||||||
|
for (unsigned p2 : latSets[s]) {
|
||||||
|
assert(!latGT(p1, p2)); // Lj => Li would be bad
|
||||||
|
if (onlyDenseDiff(p2, p1)) {
|
||||||
|
add = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert(!add || latGT(p0, p1));
|
||||||
|
}
|
||||||
|
if (add)
|
||||||
|
latSets[s].push_back(p1);
|
||||||
|
}
|
||||||
|
for (unsigned p : latSets[s])
|
||||||
|
latPoints[p].simple = simplifyCond(s, p);
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) {
|
||||||
|
// First determine if this lattice point is a *singleton*, i.e.,
|
||||||
|
// the last point in a lattice, no other is less than this one.
|
||||||
|
bool isSingleton = true;
|
||||||
|
for (unsigned p1 : latSets[s]) {
|
||||||
|
if (p0 != p1 && latGT(p0, p1)) {
|
||||||
|
isSingleton = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Now apply the two basic rules.
|
||||||
|
llvm::BitVector simple = latPoints[p0].bits;
|
||||||
|
bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
|
||||||
|
for (unsigned b = 0, be = simple.size(); b < be; b++) {
|
||||||
|
if (simple[b] && !isDim(b, Dim::kSparse)) {
|
||||||
|
if (reset)
|
||||||
|
simple.reset(b);
|
||||||
|
reset = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return simple;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Merger::latGT(unsigned i, unsigned j) const {
|
||||||
|
const llvm::BitVector &bitsi = latPoints[i].bits;
|
||||||
|
const llvm::BitVector &bitsj = latPoints[j].bits;
|
||||||
|
assert(bitsi.size() == bitsj.size());
|
||||||
|
if (bitsi.count() > bitsj.count()) {
|
||||||
|
for (unsigned b = 0, be = bitsj.size(); b < be; b++)
|
||||||
|
if (bitsj[b] && !bitsi[b])
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
|
||||||
|
llvm::BitVector tmp = latPoints[j].bits;
|
||||||
|
tmp ^= latPoints[i].bits;
|
||||||
|
return !hasAnyDimOf(tmp, Dim::kSparse);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
|
||||||
|
for (unsigned b = 0, be = bits.size(); b < be; b++)
|
||||||
|
if (bits[b] && isDim(b, d))
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sparse_tensor
|
||||||
|
} // namespace mlir
|
Loading…
Reference in New Issue