[MLIR][Sparse] Refactor lattice code into its own file
Moves iteration lattice/merger code into a new SparseTensor/Utils directory. A follow-up CL will add lattice/merger unit tests.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D104757
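The unit tests themselves only land in the follow-up CL. Purely as an illustration of what such a test could exercise against the new Merger API, here is a minimal gtest-style sketch; the test name, fixture-free structure, and checked value are hypothetical, not part of this commit:

// Hypothetical merger unit test sketch (not from this commit). Builds one
// single-point lattice set per input tensor, then checks that a conjunctive
// merge of two 1-point sets yields exactly one lattice point.
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "gtest/gtest.h"

using namespace mlir::sparse_tensor;

TEST(MergerTest, TakeConjBuildsCartesianProduct) {
  Merger merger(/*t=*/3, /*l=*/2); // two inputs + one output, two loops
  // One tensor expression and one lattice point per input, at loop index 0.
  unsigned e0 = merger.addExp(Kind::kTensor, /*e0=*/0);
  unsigned e1 = merger.addExp(Kind::kTensor, /*e0=*/1);
  unsigned s0 = merger.addSet();
  merger.set(s0).push_back(merger.addLat(/*t=*/0, /*i=*/0, e0));
  unsigned s1 = merger.addSet();
  merger.set(s1).push_back(merger.addLat(/*t=*/1, /*i=*/0, e1));
  // The cartesian product of a 1-point set with a 1-point set has 1 point.
  unsigned s = merger.takeConj(Kind::kMulF, s0, s1);
  EXPECT_EQ(merger.set(s).size(), 1u);
}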
commit 744146f60b (parent b2787945f9)
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (new file)
@@ -0,0 +1,163 @@
//===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities for dealing with iteration lattices.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_

#include "mlir/IR/Value.h"
#include "llvm/ADT/BitVector.h"

namespace mlir {
namespace sparse_tensor {

enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
enum class Dim { kSparse, kDense, kSingle, kUndef };

/// Tensor expression. Represents an MLIR expression in tensor index notation.
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
/// stored directly. For binary operations, e0 and e1 denote the indices of the
/// children tensor expressions.
struct TensorExp {
  TensorExp(Kind k, unsigned x, unsigned y, Value v)
      : kind(k), e0(x), e1(y), val(v) {
    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
  }
  Kind kind;
  /// Indices of children expression(s).
  unsigned e0;
  unsigned e1;
  /// Direct link to IR for an invariant. During code generation, this field
  /// is used to cache "hoisted" loop invariant tensor loads.
  Value val;
};

/// Lattice point. Each lattice point consists of a conjunction of tensor
/// loop indices (encoded in a bitvector) and the index of the corresponding
/// tensor expression.
struct LatPoint {
  LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
    bits.set(b);
  }
  LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
  /// Conjunction of tensor loop indices as bitvector. This represents
  /// all indices involved in the tensor expression.
  llvm::BitVector bits;
  /// Simplified conjunction of tensor loop indices as bitvector. This
  /// represents a simplified condition under which this tensor expression
  /// must execute. Pre-computed during codegen to avoid repeated eval.
  llvm::BitVector simple;
  /// Index of the tensor expression.
  unsigned exp;
};

/// A class to handle all iteration lattice operations. This class abstracts
/// away from some implementation details of storing iteration lattices and
/// tensor expressions. This allows for fine-tuning performance characteristics
/// independently from the basic algorithm if bottlenecks are identified.
class Merger {
public:
  /// Constructs a merger for the given number of tensors and loops. The
  /// user supplies the number of tensors involved in the kernel, with the
  /// last tensor in this set denoting the output tensor. The merger adds an
  /// additional synthetic tensor at the end of this set to represent all
  /// invariant expressions in the kernel.
  Merger(unsigned t, unsigned l)
      : outTensor(t - 1), numTensors(t + 1), numLoops(l),
        dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}

  /// Adds a tensor expression. Returns its index.
  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
  unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }

  /// Adds an iteration lattice point. Returns its index.
  unsigned addLat(unsigned t, unsigned i, unsigned e);

  /// Adds a new, initially empty, set. Returns its index.
  unsigned addSet();

  /// Computes a single conjunction of two lattice points by taking the "union"
  /// of loop indices (effectively constructing a larger "intersection" of
  /// those indices) with a newly constructed tensor (sub)expression of given
  /// kind. Returns the index of the new lattice point.
  unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1);

  /// Conjunctive merge of two lattice sets L0 and L1 is the conjunction of
  /// their cartesian product. Returns the index of the new set.
  unsigned takeConj(Kind kind, unsigned s0, unsigned s1);

  /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
  /// Returns the index of the new set.
  unsigned takeDisj(Kind kind, unsigned s0, unsigned s1);

  /// Optimizes the iteration lattice points in the given set. This
  /// method should be called right before code generation to avoid
  /// generating redundant loops and conditions.
  unsigned optimizeSet(unsigned s0);

  /// Simplifies the conditions in a conjunction of a given lattice point
  /// within the given set using just two basic rules:
  /// (1) multiple dense conditions are reduced to a single dense, and
  /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
  llvm::BitVector simplifyCond(unsigned s, unsigned p0);

  /// Returns true if Li > Lj.
  bool latGT(unsigned i, unsigned j) const;

  /// Returns true if Li and Lj only differ in dense.
  bool onlyDenseDiff(unsigned i, unsigned j);

  /// Bit translation.
  unsigned tensor(unsigned b) const { return b % numTensors; }
  unsigned index(unsigned b) const { return b / numTensors; }

  /// Returns true if bit corresponds to queried dim.
  bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }

  /// Returns true if bit corresponds to index of output tensor.
  bool isOutTensor(unsigned b, unsigned i) const {
    return tensor(b) == outTensor && index(b) == i;
  }

  /// Returns true if tensor access at given index has queried dim.
  bool isDim(unsigned t, unsigned i, Dim d) const {
    assert(t < numTensors && i < numLoops);
    return dims[t][i] == d;
  }

  /// Returns true if any set bit corresponds to queried dim.
  bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;

  /// Setter.
  void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }

  /// Getters.
  TensorExp &exp(unsigned e) { return tensorExps[e]; }
  LatPoint &lat(unsigned l) { return latPoints[l]; }
  SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }

private:
  const unsigned outTensor;
  const unsigned numTensors;
  const unsigned numLoops;

  std::vector<std::vector<Dim>> dims;
  llvm::SmallVector<TensorExp, 32> tensorExps;
  llvm::SmallVector<LatPoint, 16> latPoints;
  llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
};

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
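The header only declares the API. As a hedged illustration of how a client such as the sparsification pass might drive it, consider the element-wise kernel a(i) = b(i) + c(i); the driver function below and all concrete indices are hypothetical, not taken from this commit. Because addition iterates over the union of nonzeros, the operands combine through takeDisj rather than takeConj:

// Illustrative driver (hypothetical) building lattice sets for a(i) = b(i) + c(i).
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

using namespace mlir::sparse_tensor;

void buildLattices() {
  // Three tensors (b, c, and the output a) and one loop; the merger itself
  // appends a fourth, synthetic tensor for invariant expressions.
  Merger merger(/*t=*/3, /*l=*/1);
  merger.setDim(/*t=*/0, /*i=*/0, Dim::kSparse); // b is sparse in i
  merger.setDim(/*t=*/1, /*i=*/0, Dim::kSparse); // c is sparse in i
  merger.setDim(/*t=*/2, /*i=*/0, Dim::kDense);  // a is dense in i

  // One lattice set per operand, each holding a single point.
  unsigned sb = merger.addSet();
  merger.set(sb).push_back(
      merger.addLat(/*t=*/0, /*i=*/0, merger.addExp(Kind::kTensor, 0)));
  unsigned sc = merger.addSet();
  merger.set(sc).push_back(
      merger.addLat(/*t=*/1, /*i=*/0, merger.addExp(Kind::kTensor, 1)));

  // Disjunctive merge yields { b /\ c, b, c }: three lattice points, one for
  // "both nonzero" and one for each operand alone.
  unsigned s = merger.takeDisj(Kind::kAddF, sb, sc);
  unsigned opt = merger.optimizeSet(s); // prune redundant points before codegen
  (void)opt;
}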
mlir/lib/Dialect/SparseTensor/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(Utils)
@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
|
|||
MLIRSCF
|
||||
MLIRStandard
|
||||
MLIRSparseTensor
|
||||
MLIRSparseTensorUtils
|
||||
MLIRTensor
|
||||
MLIRTransforms
|
||||
MLIRVector
|
||||
|
|
|
@ -47,6 +47,7 @@
|
|||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
||||
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
@ -58,245 +59,6 @@ using namespace mlir::sparse_tensor;
|
|||
|
||||
namespace {
|
||||
|
||||
enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
|
||||
enum class Dim { kSparse, kDense, kSingle, kUndef };
|
||||
|
||||
/// Tensor expression. Represents a MLIR expression in tensor index notation.
|
||||
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
|
||||
/// stored directly. For binary operations, e0 and e1 denote the index of the
|
||||
/// children tensor expressions.
|
||||
struct TensorExp {
|
||||
TensorExp(Kind k, unsigned x, unsigned y, Value v)
|
||||
: kind(k), e0(x), e1(y), val(v) {
|
||||
assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
|
||||
(kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
|
||||
(kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
|
||||
}
|
||||
Kind kind;
|
||||
/// Indices of children expression(s).
|
||||
unsigned e0;
|
||||
unsigned e1;
|
||||
/// Direct link to IR for an invariant. During code generation,
|
||||
/// field is used to cache "hoisted" loop invariant tensor loads.
|
||||
Value val;
|
||||
};
|
||||
|
||||
/// Lattice point. Each lattice point consists of a conjunction of tensor
|
||||
/// loop indices (encoded in a bitvector) and the index of the corresponding
|
||||
/// tensor expression.
|
||||
struct LatPoint {
|
||||
LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
|
||||
bits.set(b);
|
||||
}
|
||||
LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
|
||||
/// Conjunction of tensor loop indices as bitvector. This represents
|
||||
/// all indices involved in the tensor expression
|
||||
llvm::BitVector bits;
|
||||
/// Simplified conjunction of tensor loop indices as bitvector. This
|
||||
/// represents a simplified condition under which this tensor expression
|
||||
/// must execute. Pre-computed during codegen to avoid repeated eval.
|
||||
llvm::BitVector simple;
|
||||
/// Index of the tensor expresssion.
|
||||
unsigned exp;
|
||||
};
|
||||
|
||||
/// A class to handle all iteration lattice operations. This class abstracts
|
||||
/// away from some implementation details of storing iteration lattices and
|
||||
/// tensor expressions. This allows for fine-tuning performance characteristics
|
||||
/// independently from the basic algorithm if bottlenecks are identified.
|
||||
class Merger {
|
||||
public:
|
||||
/// Constructs a merger for the given number of tensors and loops. The
|
||||
/// user supplies the number of tensors involved in the kernel, with the
|
||||
/// last tensor in this set denoting the output tensor. The merger adds an
|
||||
/// additional synthetic tensor at the end of this set to represent all
|
||||
/// invariant expressions in the kernel.
|
||||
Merger(unsigned t, unsigned l)
|
||||
: outTensor(t - 1), numTensors(t + 1), numLoops(l),
|
||||
dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
|
||||
|
||||
/// Adds a tensor expression. Returns its index.
|
||||
unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
|
||||
unsigned e = tensorExps.size();
|
||||
tensorExps.push_back(TensorExp(k, e0, e1, v));
|
||||
return e;
|
||||
}
|
||||
unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
|
||||
|
||||
/// Adds an iteration lattice point. Returns its index.
|
||||
unsigned addLat(unsigned t, unsigned i, unsigned e) {
|
||||
assert(t < numTensors && i < numLoops);
|
||||
unsigned p = latPoints.size();
|
||||
latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
|
||||
return p;
|
||||
}
|
||||
|
||||
/// Adds a new, initially empty, set. Returns its index.
|
||||
unsigned addSet() {
|
||||
unsigned s = latSets.size();
|
||||
latSets.emplace_back(SmallVector<unsigned, 16>());
|
||||
return s;
|
||||
}
|
||||
|
||||
/// Computes a single conjunction of two lattice points by taking the "union"
|
||||
/// of loop indices (effectively constructing a larger "intersection" of those
|
||||
/// indices) with a newly constructed tensor (sub)expression of given kind.
|
||||
/// Returns the index of the new lattice point.
|
||||
unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
|
||||
unsigned p = latPoints.size();
|
||||
llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
|
||||
nb |= latPoints[p1].bits;
|
||||
unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
|
||||
latPoints.push_back(LatPoint(nb, e));
|
||||
return p;
|
||||
}
|
||||
|
||||
/// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
|
||||
/// cartesian product. Returns the index of the new set.
|
||||
unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
|
||||
unsigned s = addSet();
|
||||
for (unsigned p0 : latSets[s0])
|
||||
for (unsigned p1 : latSets[s1])
|
||||
latSets[s].push_back(conjLatPoint(kind, p0, p1));
|
||||
return s;
|
||||
}
|
||||
|
||||
/// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
|
||||
/// Returns the index of the new set.
|
||||
unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
|
||||
unsigned s = takeConj(kind, s0, s1);
|
||||
for (unsigned p : latSets[s0])
|
||||
latSets[s].push_back(p);
|
||||
for (unsigned p : latSets[s1])
|
||||
latSets[s].push_back(p);
|
||||
return s;
|
||||
}
|
||||
|
||||
/// Optimizes the iteration lattice points in the given set. This
|
||||
/// method should be called right before code generation to avoid
|
||||
/// generating redundant loops and conditions.
|
||||
unsigned optimizeSet(unsigned s0) {
|
||||
unsigned s = addSet();
|
||||
assert(latSets[s0].size() != 0);
|
||||
unsigned p0 = latSets[s0][0];
|
||||
for (unsigned p1 : latSets[s0]) {
|
||||
bool add = true;
|
||||
if (p0 != p1) {
|
||||
// Is this a straightforward copy?
|
||||
unsigned e = latPoints[p1].exp;
|
||||
if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
|
||||
continue;
|
||||
// Conjunction already covered?
|
||||
for (unsigned p2 : latSets[s]) {
|
||||
assert(!latGT(p1, p2)); // Lj => Li would be bad
|
||||
if (onlyDenseDiff(p2, p1)) {
|
||||
add = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert(!add || latGT(p0, p1));
|
||||
}
|
||||
if (add)
|
||||
latSets[s].push_back(p1);
|
||||
}
|
||||
for (unsigned p : latSets[s])
|
||||
latPoints[p].simple = simplifyCond(s, p);
|
||||
return s;
|
||||
}
|
||||
|
||||
/// Simplifies the conditions in a conjunction of a given lattice point
|
||||
/// within the given set using just two basic rules:
|
||||
/// (1) multiple dense conditions are reduced to single dense, and
|
||||
/// (2) a *singleton* sparse/dense is reduced to sparse/random access.
|
||||
llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
|
||||
// First determine if this lattice point is a *singleton*, i.e.,
|
||||
// the last point in a lattice, no other is less than this one.
|
||||
bool isSingleton = true;
|
||||
for (unsigned p1 : latSets[s]) {
|
||||
if (p0 != p1 && latGT(p0, p1)) {
|
||||
isSingleton = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Now apply the two basic rules.
|
||||
llvm::BitVector simple = latPoints[p0].bits;
|
||||
bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
|
||||
for (unsigned b = 0, be = simple.size(); b < be; b++) {
|
||||
if (simple[b] && !isDim(b, Dim::kSparse)) {
|
||||
if (reset)
|
||||
simple.reset(b);
|
||||
reset = true;
|
||||
}
|
||||
}
|
||||
return simple;
|
||||
}
|
||||
|
||||
/// Returns true if Li > Lj.
|
||||
bool latGT(unsigned i, unsigned j) const {
|
||||
const llvm::BitVector &bitsi = latPoints[i].bits;
|
||||
const llvm::BitVector &bitsj = latPoints[j].bits;
|
||||
assert(bitsi.size() == bitsj.size());
|
||||
if (bitsi.count() > bitsj.count()) {
|
||||
for (unsigned b = 0, be = bitsj.size(); b < be; b++)
|
||||
if (bitsj[b] && !bitsi[b])
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Returns true if Li and Lj only differ in dense.
|
||||
bool onlyDenseDiff(unsigned i, unsigned j) {
|
||||
llvm::BitVector tmp = latPoints[j].bits;
|
||||
tmp ^= latPoints[i].bits;
|
||||
return !hasAnyDimOf(tmp, Dim::kSparse);
|
||||
}
|
||||
|
||||
/// Bit translation.
|
||||
unsigned tensor(unsigned b) const { return b % numTensors; }
|
||||
unsigned index(unsigned b) const { return b / numTensors; }
|
||||
|
||||
/// Returns true if bit corresponds to queried dim.
|
||||
bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
|
||||
|
||||
/// Returns true if bit corresponds to index of output tensor.
|
||||
bool isOutTensor(unsigned b, unsigned i) const {
|
||||
return tensor(b) == outTensor && index(b) == i;
|
||||
}
|
||||
|
||||
/// Returns true if tensor access at given index has queried dim.
|
||||
bool isDim(unsigned t, unsigned i, Dim d) const {
|
||||
assert(t < numTensors && i < numLoops);
|
||||
return dims[t][i] == d;
|
||||
}
|
||||
|
||||
/// Returns true if any set bit corresponds to queried dim.
|
||||
bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
|
||||
for (unsigned b = 0, be = bits.size(); b < be; b++)
|
||||
if (bits[b] && isDim(b, d))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Setter
|
||||
void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
|
||||
|
||||
/// Getters.
|
||||
TensorExp &exp(unsigned e) { return tensorExps[e]; }
|
||||
LatPoint &lat(unsigned l) { return latPoints[l]; }
|
||||
SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
|
||||
|
||||
private:
|
||||
const unsigned outTensor;
|
||||
const unsigned numTensors;
|
||||
const unsigned numLoops;
|
||||
|
||||
std::vector<std::vector<Dim>> dims;
|
||||
llvm::SmallVector<TensorExp, 32> tensorExps;
|
||||
llvm::SmallVector<LatPoint, 16> latPoints;
|
||||
llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
|
||||
};
|
||||
|
||||
// Code generation.
|
||||
struct CodeGen {
|
||||
CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
add_mlir_dialect_library(MLIRSparseTensorUtils
|
||||
Merger.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
)
|
|
@ -0,0 +1,138 @@
|
|||
//===- Merger.cpp - Implementation of iteration lattices ------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace sparse_tensor {
|
||||
|
||||
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
|
||||
unsigned e = tensorExps.size();
|
||||
tensorExps.push_back(TensorExp(k, e0, e1, v));
|
||||
return e;
|
||||
}
|
||||
|
||||
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
|
||||
assert(t < numTensors && i < numLoops);
|
||||
unsigned p = latPoints.size();
|
||||
latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
|
||||
return p;
|
||||
}
|
||||
|
||||
unsigned Merger::addSet() {
|
||||
unsigned s = latSets.size();
|
||||
latSets.emplace_back(SmallVector<unsigned, 16>());
|
||||
return s;
|
||||
}
|
||||
|
||||
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
|
||||
unsigned p = latPoints.size();
|
||||
llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
|
||||
nb |= latPoints[p1].bits;
|
||||
unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
|
||||
latPoints.push_back(LatPoint(nb, e));
|
||||
return p;
|
||||
}
|
||||
|
||||
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
|
||||
unsigned s = addSet();
|
||||
for (unsigned p0 : latSets[s0])
|
||||
for (unsigned p1 : latSets[s1])
|
||||
latSets[s].push_back(conjLatPoint(kind, p0, p1));
|
||||
return s;
|
||||
}
|
||||
|
||||
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
|
||||
unsigned s = takeConj(kind, s0, s1);
|
||||
for (unsigned p : latSets[s0])
|
||||
latSets[s].push_back(p);
|
||||
for (unsigned p : latSets[s1])
|
||||
latSets[s].push_back(p);
|
||||
return s;
|
||||
}
|
||||
|
||||
unsigned Merger::optimizeSet(unsigned s0) {
|
||||
unsigned s = addSet();
|
||||
assert(latSets[s0].size() != 0);
|
||||
unsigned p0 = latSets[s0][0];
|
||||
for (unsigned p1 : latSets[s0]) {
|
||||
bool add = true;
|
||||
if (p0 != p1) {
|
||||
// Is this a straightforward copy?
|
||||
unsigned e = latPoints[p1].exp;
|
||||
if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
|
||||
continue;
|
||||
// Conjunction already covered?
|
||||
for (unsigned p2 : latSets[s]) {
|
||||
assert(!latGT(p1, p2)); // Lj => Li would be bad
|
||||
if (onlyDenseDiff(p2, p1)) {
|
||||
add = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert(!add || latGT(p0, p1));
|
||||
}
|
||||
if (add)
|
||||
latSets[s].push_back(p1);
|
||||
}
|
||||
for (unsigned p : latSets[s])
|
||||
latPoints[p].simple = simplifyCond(s, p);
|
||||
return s;
|
||||
}
|
||||
|
||||
llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) {
|
||||
// First determine if this lattice point is a *singleton*, i.e.,
|
||||
// the last point in a lattice, no other is less than this one.
|
||||
bool isSingleton = true;
|
||||
for (unsigned p1 : latSets[s]) {
|
||||
if (p0 != p1 && latGT(p0, p1)) {
|
||||
isSingleton = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Now apply the two basic rules.
|
||||
llvm::BitVector simple = latPoints[p0].bits;
|
||||
bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
|
||||
for (unsigned b = 0, be = simple.size(); b < be; b++) {
|
||||
if (simple[b] && !isDim(b, Dim::kSparse)) {
|
||||
if (reset)
|
||||
simple.reset(b);
|
||||
reset = true;
|
||||
}
|
||||
}
|
||||
return simple;
|
||||
}
|
||||
|
||||
bool Merger::latGT(unsigned i, unsigned j) const {
|
||||
const llvm::BitVector &bitsi = latPoints[i].bits;
|
||||
const llvm::BitVector &bitsj = latPoints[j].bits;
|
||||
assert(bitsi.size() == bitsj.size());
|
||||
if (bitsi.count() > bitsj.count()) {
|
||||
for (unsigned b = 0, be = bitsj.size(); b < be; b++)
|
||||
if (bitsj[b] && !bitsi[b])
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
|
||||
llvm::BitVector tmp = latPoints[j].bits;
|
||||
tmp ^= latPoints[i].bits;
|
||||
return !hasAnyDimOf(tmp, Dim::kSparse);
|
||||
}
|
||||
|
||||
bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
|
||||
for (unsigned b = 0, be = bits.size(); b < be; b++)
|
||||
if (bits[b] && isDim(b, d))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace sparse_tensor
|
||||
} // namespace mlir
|