[mlir][sparse] Factoring out an enumerator over elements of SparseTensorStorage

Work towards fixing: https://github.com/llvm/llvm-project/issues/51652

Depends On D122928

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D122060
This commit is contained in:
wren romano 2022-05-11 15:58:42 -07:00
parent 8b9caad8eb
commit 753fe330c1
1 changed file with 172 additions and 46 deletions
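At a glance, this change replaces the hand-written recursive `toCOO` walk with a reusable `SparseTensorEnumerator`. Below is a minimal sketch of the resulting pattern, mirroring the rewritten `toCOO` in the diff; the helper name `toCOOViaEnumerator`, its `capacity` parameter, and the concrete `<uint64_t, uint64_t, double>` instantiation are illustrative placeholders, not part of the commit.

// Sketch only: enumerate all stored elements of `tensor` in its own storage
// order, permute each index vector into the order given by `perm`, and hand
// each (indices, value) pair to a callback that builds a COO tensor.
SparseTensorCOO<double> *
toCOOViaEnumerator(const SparseTensorStorage<uint64_t, uint64_t, double> &tensor,
                   const uint64_t *perm, uint64_t capacity) {
  SparseTensorEnumerator<uint64_t, uint64_t, double> enumerator(
      tensor, tensor.getRank(), perm);
  auto *coo = new SparseTensorCOO<double>(enumerator.permutedSizes(), capacity);
  enumerator.forallElements([&coo](const std::vector<uint64_t> &ind, double v) {
    coo->add(ind, v);
  });
  return coo;
}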


@@ -27,6 +27,7 @@
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <limits>
#include <numeric>
@@ -94,6 +95,13 @@ struct Element {
V value;
};
/// The type of callback functions which receive an element. We avoid
/// packaging the coordinates and value together as an `Element` object
/// because this helps keep code somewhat cleaner.
template <typename V>
using ElementConsumer =
const std::function<void(const std::vector<uint64_t> &, V)> &;
/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
@@ -220,6 +228,7 @@ public:
const uint64_t *perm, const DimLevelType *sparsity)
: dimSizes(szs), rev(getRank()),
dimTypes(sparsity, sparsity + getRank()) {
assert(perm && sparsity);
const uint64_t rank = getRank();
// Validate parameters.
assert(rank > 0 && "Trivial shape is unsupported");
@@ -310,6 +319,16 @@ public:
/// Finishes insertion.
virtual void endInsert() = 0;
protected:
// Since this class is a polymorphic base class, we must disallow public
// copying in order to avoid "slicing". Since this class has data members,
// that means making copying protected.
// <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
// Copy-assignment would be implicitly deleted (because `dimSizes`
// is const), so we explicitly delete it for clarity.
SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
private:
static void fatal(const char *tp) {
fprintf(stderr, "unsupported %s\n", tp);
@@ -321,6 +340,10 @@ private:
const std::vector<DimLevelType> dimTypes;
};
// Forward.
template <typename P, typename I, typename V>
class SparseTensorEnumerator;
/// A memory-resident sparse tensor using a storage scheme based on
/// per-dimension sparse/dense annotations. This data structure provides a
/// bufferized form of a sparse tensor type. In contrast to generating setup
@@ -443,24 +466,13 @@ public:
/// sparse tensor in coordinate scheme with the given dimension order.
///
/// Precondition: `perm` must be valid for `getRank()`.
SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
// Restore original order of the dimension sizes and allocate coordinate
// scheme with desired new ordering specified in perm.
const uint64_t rank = getRank();
const auto &rev = getRev();
const auto &sizes = getDimSizes();
std::vector<uint64_t> orgsz(rank);
for (uint64_t r = 0; r < rank; r++)
orgsz[rev[r]] = sizes[r];
SparseTensorCOO<V> *coo = SparseTensorCOO<V>::newSparseTensorCOO(
rank, orgsz.data(), perm, values.size());
// Populate coordinate scheme restored from old ordering and changed with
// new ordering. Rather than applying both reorderings during the recursion,
// we compute the combine permutation in advance.
std::vector<uint64_t> reord(rank);
for (uint64_t r = 0; r < rank; r++)
reord[r] = perm[rev[r]];
toCOO(*coo, reord, 0, 0);
SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
SparseTensorEnumerator<P, I, V> enumerator(*this, getRank(), perm);
SparseTensorCOO<V> *coo =
new SparseTensorCOO<V>(enumerator.permutedSizes(), values.size());
enumerator.forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
coo->add(ind, val);
});
// TODO: This assertion assumes there are no stored zeros,
// or if there are then that we don't filter them out.
// Cf., <https://github.com/llvm/llvm-project/issues/54179>
@@ -543,9 +555,10 @@ private:
/// and pointwise less-than).
void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
uint64_t hi, uint64_t d) {
uint64_t rank = getRank();
assert(d <= rank && hi <= elements.size());
// Once dimensions are exhausted, insert the numerical values.
assert(d <= getRank() && hi <= elements.size());
if (d == getRank()) {
if (d == rank) {
assert(lo < hi);
values.push_back(elements[lo].value);
return;
@@ -569,31 +582,6 @@ private:
finalizeSegment(d, full);
}
/// Stores the sparse tensor storage scheme into a memory-resident sparse
/// tensor in coordinate scheme.
void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord,
uint64_t pos, uint64_t d) {
assert(d <= getRank());
if (d == getRank()) {
assert(pos < values.size());
tensor.add(idx, values[pos]);
} else if (isCompressedDim(d)) {
// Sparse dimension.
for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
idx[reord[d]] = indices[d][ii];
toCOO(tensor, reord, ii, d + 1);
}
} else {
// Dense dimension.
const uint64_t sz = getDimSizes()[d];
const uint64_t off = pos * sz;
for (uint64_t i = 0; i < sz; i++) {
idx[reord[d]] = i;
toCOO(tensor, reord, off + i, d + 1);
}
}
}
/// Finalize the sparse pointer structure at this dimension.
void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
if (count == 0)
@@ -649,13 +637,151 @@ private:
return -1u;
}
private:
// Allow `SparseTensorEnumerator` to access the data-members (to avoid
// the cost of virtual-function dispatch in inner loops), without
// making them public to other client code.
friend class SparseTensorEnumerator<P, I, V>;
std::vector<std::vector<P>> pointers;
std::vector<std::vector<I>> indices;
std::vector<V> values;
std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
};
/// A (higher-order) function object for enumerating the elements of some
/// `SparseTensorStorage` under a permutation. That is, the `forallElements`
/// method encapsulates the loop-nest for enumerating the elements of
/// the source tensor (in whatever order is best for the source tensor),
/// and applies a permutation to the coordinates/indices before handing
/// each element to the callback. A single enumerator object can be
/// freely reused for several calls to `forallElements`, so long as
/// those calls are sequential (i.e., not interleaved) with one another.
///
/// N.B., this class stores a reference to the `SparseTensorStorageBase`
/// passed to the constructor; thus, objects of this class must not
/// outlive the sparse tensor they depend on.
///
/// Design Note: The reason we define this class instead of simply using
/// `SparseTensorEnumerator<P,I,V>` is that we need to hide/generalize
/// the `<P,I>` template parameters from MLIR client code (to simplify the
/// type parameters used for direct sparse-to-sparse conversion). And the
/// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
/// than simply using this class is to avoid the cost of virtual-method
/// dispatch within the loop-nest.
template <typename V>
class SparseTensorEnumeratorBase {
public:
/// Constructs an enumerator with the given permutation for mapping
/// the semantic-ordering of dimensions to the desired target-ordering.
///
/// Preconditions:
/// * the `tensor` must have the same `V` value type.
/// * `perm` must be valid for `rank`.
SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
uint64_t rank, const uint64_t *perm)
: src(tensor), permsz(src.getRev().size()), reord(getRank()),
cursor(getRank()) {
assert(perm && "Received nullptr for permutation");
assert(rank == getRank() && "Permutation rank mismatch");
const auto &rev = src.getRev(); // source stg-order -> semantic-order
const auto &sizes = src.getDimSizes(); // in source storage-order
for (uint64_t s = 0; s < rank; s++) { // `s` source storage-order
uint64_t t = perm[rev[s]]; // `t` target-order
reord[s] = t;
permsz[t] = sizes[s];
}
}
virtual ~SparseTensorEnumeratorBase() = default;
// We disallow copying to help avoid leaking the `src` reference.
// (In addition to avoiding the problem of slicing.)
SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
SparseTensorEnumeratorBase &
operator=(const SparseTensorEnumeratorBase &) = delete;
/// Returns the source/target tensor's rank. (The source-rank and
/// target-rank are always equal since we only support permutations.
/// Though once we add support for other dimension mappings, this
/// method will have to be split in two.)
uint64_t getRank() const { return permsz.size(); }
/// Returns the target tensor's dimension sizes.
const std::vector<uint64_t> &permutedSizes() const { return permsz; }
/// Enumerates all elements of the source tensor, permutes their
/// indices, and passes the permuted element to the callback.
/// The callback must not store the cursor reference directly,
/// since this function reuses the storage. Instead, the callback
/// must copy it if it wants to keep it.
virtual void forallElements(ElementConsumer<V> yield) = 0;
protected:
const SparseTensorStorageBase &src;
std::vector<uint64_t> permsz; // in target order.
std::vector<uint64_t> reord; // source storage-order -> target order.
std::vector<uint64_t> cursor; // in target order.
};
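To make the design note above concrete: because this base class erases the `<P,I>` parameters, client code that only cares about the value type `V` can consume elements generically. A small hypothetical helper sketching that usage (the function name is illustrative and not part of this change):

// Counts the stored elements of whatever tensor the enumerator wraps,
// without knowing its pointer/index types <P,I>.
template <typename V>
uint64_t countStoredElements(SparseTensorEnumeratorBase<V> &enumerator) {
  uint64_t count = 0;
  enumerator.forallElements(
      [&count](const std::vector<uint64_t> &, V) { ++count; });
  return count;
}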
template <typename P, typename I, typename V>
class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
using Base = SparseTensorEnumeratorBase<V>;
public:
/// Constructs an enumerator with the given permutation for mapping
/// the semantic-ordering of dimensions to the desired target-ordering.
///
/// Precondition: `perm` must be valid for `rank`.
SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
uint64_t rank, const uint64_t *perm)
: Base(tensor, rank, perm) {}
~SparseTensorEnumerator() final override = default;
void forallElements(ElementConsumer<V> yield) final override {
forallElements(yield, 0, 0);
}
private:
/// The recursive component of the public `forallElements`.
void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
uint64_t d) {
// Recover the `<P,I,V>` type parameters of `src`.
const auto &src =
static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
if (d == Base::getRank()) {
assert(parentPos < src.values.size() &&
"Value position is out of bounds");
// TODO: <https://github.com/llvm/llvm-project/issues/54179>
yield(this->cursor, src.values[parentPos]);
} else if (src.isCompressedDim(d)) {
// Look up the bounds of the `d`-level segment determined by the
// `d-1`-level position `parentPos`.
const std::vector<P> &pointers_d = src.pointers[d];
assert(parentPos + 1 < pointers_d.size() &&
"Parent pointer position is out of bounds");
const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
// Loop-invariant code for looking up the `d`-level coordinates/indices.
const std::vector<I> &indices_d = src.indices[d];
assert(pstop - 1 < indices_d.size() && "Index position is out of bounds");
uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
for (uint64_t pos = pstart; pos < pstop; pos++) {
cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
forallElements(yield, pos, d + 1);
}
} else { // Dense dimension.
const uint64_t sz = src.getDimSizes()[d];
const uint64_t pstart = parentPos * sz;
uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
for (uint64_t i = 0; i < sz; i++) {
cursor_reord_d = i;
forallElements(yield, pstart + i, d + 1);
}
}
}
};
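A practical note on the `forallElements` contract documented above: the cursor vector is reused for every element, so a callback that retains indices must copy them. A hedged sketch, assuming `enumerator` is an instance like the one constructed in `toCOO` and `double` is the value type:

// Collect elements into owned storage. `emplace_back(ind, v)` copies the
// cursor's contents; storing `&ind` (or a reference) instead would leave
// every entry aliasing the same, repeatedly overwritten vector.
std::vector<std::pair<std::vector<uint64_t>, double>> elems;
enumerator.forallElements(
    [&elems](const std::vector<uint64_t> &ind, double v) {
      elems.emplace_back(ind, v);
    });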
/// Helper to convert string to lower case.
static char *toLower(char *token) {
for (char *c = token; *c; c++)