[mlir][sparse] Factoring out an enumerator over elements of SparseTensorStorage

Work towards fixing https://github.com/llvm/llvm-project/issues/51652

Depends On D122928

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D122060
parent 8b9caad8eb
commit 753fe330c1
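In brief (an editor's sketch, not part of the commit message): instead of the storage class hard-coding a recursive loop-nest that builds a SparseTensorCOO, enumeration is factored into a SparseTensorEnumerator that walks the storage and hands each element, with indices already permuted into the target order, to a callback. Assuming `tensor` is a `SparseTensorStorage<P, I, V>` and `perm` a valid permutation, the new pattern looks like:

    SparseTensorEnumerator<P, I, V> enumerator(tensor, tensor.getRank(), perm);
    enumerator.forallElements([](const std::vector<uint64_t> &ind, V val) {
      // `ind` is already permuted into the target dimension order;
      // any consumer (COO construction, direct conversion, ...) goes here.
    });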
@@ -27,6 +27,7 @@
 #include <cstdlib>
 #include <cstring>
 #include <fstream>
+#include <functional>
 #include <iostream>
 #include <limits>
 #include <numeric>
@@ -94,6 +95,13 @@ struct Element {
   V value;
 };

+/// The type of callback functions which receive an element. We avoid
+/// packaging the coordinates and value together as an `Element` object
+/// because this helps keep code somewhat cleaner.
+template <typename V>
+using ElementConsumer =
+    const std::function<void(const std::vector<uint64_t> &, V)> &;
+
 /// A memory-resident sparse tensor in coordinate scheme (collection of
 /// elements). This data structure is used to read a sparse tensor from
 /// any external format into memory and sort the elements lexicographically
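An aside on the callback type (illustration only, not from the patch): since `ElementConsumer<V>` is a reference to a `const std::function`, call sites can pass any callable with a matching signature, such as a capturing lambda; a temporary `std::function` is materialized at the call:

    // Hypothetical consumer that counts the stored elements of a double tensor.
    uint64_t count = 0;
    auto countElements = [&count](const std::vector<uint64_t> &ind, double val) {
      (void)ind; // coordinates are not needed for counting
      (void)val;
      ++count;
    };
    // `countElements` converts implicitly wherever an ElementConsumer<double>
    // parameter is expected, e.g. forallElements(countElements).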
@@ -220,6 +228,7 @@ public:
                           const uint64_t *perm, const DimLevelType *sparsity)
       : dimSizes(szs), rev(getRank()),
         dimTypes(sparsity, sparsity + getRank()) {
+    assert(perm && sparsity);
     const uint64_t rank = getRank();
     // Validate parameters.
     assert(rank > 0 && "Trivial shape is unsupported");
@@ -310,6 +319,16 @@ public:
   /// Finishes insertion.
   virtual void endInsert() = 0;

+protected:
+  // Since this class is virtual, we must disallow public copying in
+  // order to avoid "slicing". Since this class has data members,
+  // that means making copying protected.
+  // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
+  SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
+  // Copy-assignment would be implicitly deleted (because `dimSizes`
+  // is const), so we explicitly delete it for clarity.
+  SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
+
 private:
   static void fatal(const char *tp) {
     fprintf(stderr, "unsupported %s\n", tp);
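Aside on the "slicing" the added comment guards against (a generic editor's illustration, not code from the patch):

    struct Base { virtual ~Base() = default; int b = 0; };
    struct Derived : Base { int d = 42; };

    Derived obj;
    Base sliced = obj; // compiles when Base is publicly copyable: `d` is lost

Making the base's copy constructor protected leaves subclasses free to copy their base sub-object while preventing client code from creating such truncated copies.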
@@ -321,6 +340,10 @@ private:
   const std::vector<DimLevelType> dimTypes;
 };

+// Forward.
+template <typename P, typename I, typename V>
+class SparseTensorEnumerator;
+
 /// A memory-resident sparse tensor using a storage scheme based on
 /// per-dimension sparse/dense annotations. This data structure provides a
 /// bufferized form of a sparse tensor type. In contrast to generating setup
@@ -443,24 +466,13 @@ public:
   /// sparse tensor in coordinate scheme with the given dimension order.
   ///
   /// Precondition: `perm` must be valid for `getRank()`.
-  SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
-    // Restore original order of the dimension sizes and allocate coordinate
-    // scheme with desired new ordering specified in perm.
-    const uint64_t rank = getRank();
-    const auto &rev = getRev();
-    const auto &sizes = getDimSizes();
-    std::vector<uint64_t> orgsz(rank);
-    for (uint64_t r = 0; r < rank; r++)
-      orgsz[rev[r]] = sizes[r];
-    SparseTensorCOO<V> *coo = SparseTensorCOO<V>::newSparseTensorCOO(
-        rank, orgsz.data(), perm, values.size());
-    // Populate coordinate scheme restored from old ordering and changed with
-    // new ordering. Rather than applying both reorderings during the
-    // recursion, we compute the combined permutation in advance.
-    std::vector<uint64_t> reord(rank);
-    for (uint64_t r = 0; r < rank; r++)
-      reord[r] = perm[rev[r]];
-    toCOO(*coo, reord, 0, 0);
+  SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
+    SparseTensorEnumerator<P, I, V> enumerator(*this, getRank(), perm);
+    SparseTensorCOO<V> *coo =
+        new SparseTensorCOO<V>(enumerator.permutedSizes(), values.size());
+    enumerator.forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
+      coo->add(ind, val);
+    });
     // TODO: This assertion assumes there are no stored zeros,
     // or if there are then that we don't filter them out.
     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
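The payoff of this factoring (an editor's sketch of the direction in issue 51652, not code from this patch) is that a direct sparse-to-sparse conversion can drive the same enumerator and feed elements straight into the target storage, with no intermediate COO object; `insertElement` below is a hypothetical insertion API:

    SparseTensorEnumerator<P, I, V> enumerator(srcTensor, rank, perm);
    enumerator.forallElements([&](const std::vector<uint64_t> &ind, V val) {
      dstTensor.insertElement(ind, val); // hypothetical target-side insertion
    });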
@@ -543,9 +555,10 @@ private:
   /// and pointwise less-than).
   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
                uint64_t hi, uint64_t d) {
+    uint64_t rank = getRank();
+    assert(d <= rank && hi <= elements.size());
     // Once dimensions are exhausted, insert the numerical values.
-    assert(d <= getRank() && hi <= elements.size());
-    if (d == getRank()) {
+    if (d == rank) {
       assert(lo < hi);
       values.push_back(elements[lo].value);
       return;
@@ -569,31 +582,6 @@ private:
     finalizeSegment(d, full);
   }

-  /// Stores the sparse tensor storage scheme into a memory-resident sparse
-  /// tensor in coordinate scheme.
-  void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord,
-             uint64_t pos, uint64_t d) {
-    assert(d <= getRank());
-    if (d == getRank()) {
-      assert(pos < values.size());
-      tensor.add(idx, values[pos]);
-    } else if (isCompressedDim(d)) {
-      // Sparse dimension.
-      for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
-        idx[reord[d]] = indices[d][ii];
-        toCOO(tensor, reord, ii, d + 1);
-      }
-    } else {
-      // Dense dimension.
-      const uint64_t sz = getDimSizes()[d];
-      const uint64_t off = pos * sz;
-      for (uint64_t i = 0; i < sz; i++) {
-        idx[reord[d]] = i;
-        toCOO(tensor, reord, off + i, d + 1);
-      }
-    }
-  }
-
   /// Finalize the sparse pointer structure at this dimension.
   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
     if (count == 0)
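For intuition about the index bookkeeping in the enumerator added in the next hunk, consider a worked example with made-up values: take a rank-3 tensor whose storage dimension `s` holds semantic dimension `rev[s]`, with `rev = [2, 0, 1]`, and a requested target order `perm = [2, 0, 1]` (semantic dimension `j` lands at target position `perm[j]`). Composing the two gives `reord[s] = perm[rev[s]]`, i.e. `reord = [1, 2, 0]`, so the index found at storage level 0 is written directly into cursor slot 1, and so on; likewise `permsz[reord[s]] = sizes[s]` lays out the dimension sizes in target order. Computing `reord` once in the constructor means the recursion applies a single composed permutation per index instead of two.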
@@ -649,13 +637,151 @@ private:
     return -1u;
   }

 private:
+  // Allow `SparseTensorEnumerator` to access the data-members (to avoid
+  // the cost of virtual-function dispatch in inner loops), without
+  // making them public to other client code.
+  friend class SparseTensorEnumerator<P, I, V>;
+
   std::vector<std::vector<P>> pointers;
   std::vector<std::vector<I>> indices;
   std::vector<V> values;
   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
 };

+/// A (higher-order) function object for enumerating the elements of some
+/// `SparseTensorStorage` under a permutation. That is, the `forallElements`
+/// method encapsulates the loop-nest for enumerating the elements of
+/// the source tensor (in whatever order is best for the source tensor),
+/// and applies a permutation to the coordinates/indices before handing
+/// each element to the callback. A single enumerator object can be
+/// freely reused for several calls to `forallElements`, just so long
+/// as the calls are sequential with respect to one another.
+///
+/// N.B., this class stores a reference to the `SparseTensorStorageBase`
+/// passed to the constructor; thus, objects of this class must not
+/// outlive the sparse tensor they depend on.
+///
+/// Design Note: The reason we define this class instead of simply using
+/// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
+/// the `<P,I>` template parameters from MLIR client code (to simplify the
+/// type parameters used for direct sparse-to-sparse conversion). And the
+/// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
+/// than simply using this class, is to avoid the cost of virtual-method
+/// dispatch within the loop-nest.
+template <typename V>
+class SparseTensorEnumeratorBase {
+public:
+  /// Constructs an enumerator with the given permutation for mapping
+  /// the semantic-ordering of dimensions to the desired target-ordering.
+  ///
+  /// Preconditions:
+  /// * the `tensor` must have the same `V` value type.
+  /// * `perm` must be valid for `rank`.
+  SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
+                             uint64_t rank, const uint64_t *perm)
+      : src(tensor), permsz(src.getRev().size()), reord(getRank()),
+        cursor(getRank()) {
+    assert(perm && "Received nullptr for permutation");
+    assert(rank == getRank() && "Permutation rank mismatch");
+    const auto &rev = src.getRev();        // source stg-order -> semantic-order
+    const auto &sizes = src.getDimSizes(); // in source storage-order
+    for (uint64_t s = 0; s < rank; s++) {  // `s` source storage-order
+      uint64_t t = perm[rev[s]];           // `t` target-order
+      reord[s] = t;
+      permsz[t] = sizes[s];
+    }
+  }
+
+  virtual ~SparseTensorEnumeratorBase() = default;
+
+  // We disallow copying to help avoid leaking the `src` reference.
+  // (In addition to avoiding the problem of slicing.)
+  SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
+  SparseTensorEnumeratorBase &
+  operator=(const SparseTensorEnumeratorBase &) = delete;
+
+  /// Returns the source/target tensor's rank. (The source-rank and
+  /// target-rank are always equal since we only support permutations.
+  /// Though once we add support for other dimension mappings, this
+  /// method will have to be split in two.)
+  uint64_t getRank() const { return permsz.size(); }
+
+  /// Returns the target tensor's dimension sizes.
+  const std::vector<uint64_t> &permutedSizes() const { return permsz; }
+
+  /// Enumerates all elements of the source tensor, permutes their
+  /// indices, and passes the permuted element to the callback.
+  /// The callback must not store the cursor reference directly,
+  /// since this function reuses the storage. Instead, the callback
+  /// must copy it if it wants to keep it.
+  virtual void forallElements(ElementConsumer<V> yield) = 0;
+
+protected:
+  const SparseTensorStorageBase &src;
+  std::vector<uint64_t> permsz; // in target order.
+  std::vector<uint64_t> reord;  // source storage-order -> target order.
+  std::vector<uint64_t> cursor; // in target order.
+};
+
+template <typename P, typename I, typename V>
+class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
+  using Base = SparseTensorEnumeratorBase<V>;
+
+public:
+  /// Constructs an enumerator with the given permutation for mapping
+  /// the semantic-ordering of dimensions to the desired target-ordering.
+  ///
+  /// Precondition: `perm` must be valid for `rank`.
+  SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
+                         uint64_t rank, const uint64_t *perm)
+      : Base(tensor, rank, perm) {}
+
+  ~SparseTensorEnumerator() final override = default;
+
+  void forallElements(ElementConsumer<V> yield) final override {
+    forallElements(yield, 0, 0);
+  }
+
+private:
+  /// The recursive component of the public `forallElements`.
+  void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
+                      uint64_t d) {
+    // Recover the `<P,I,V>` type parameters of `src`.
+    const auto &src =
+        static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
+    if (d == Base::getRank()) {
+      assert(parentPos < src.values.size() &&
+             "Value position is out of bounds");
+      // TODO: <https://github.com/llvm/llvm-project/issues/54179>
+      yield(this->cursor, src.values[parentPos]);
+    } else if (src.isCompressedDim(d)) {
+      // Look up the bounds of the `d`-level segment determined by the
+      // `d-1`-level position `parentPos`.
+      const std::vector<P> &pointers_d = src.pointers[d];
+      assert(parentPos + 1 < pointers_d.size() &&
+             "Parent pointer position is out of bounds");
+      const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
+      const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
+      // Loop-invariant code for looking up the `d`-level coordinates/indices.
+      const std::vector<I> &indices_d = src.indices[d];
+      assert(pstop <= indices_d.size() && "Index position is out of bounds");
+      uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
+      for (uint64_t pos = pstart; pos < pstop; pos++) {
+        cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
+        forallElements(yield, pos, d + 1);
+      }
+    } else { // Dense dimension.
+      const uint64_t sz = src.getDimSizes()[d];
+      const uint64_t pstart = parentPos * sz;
+      uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
+      for (uint64_t i = 0; i < sz; i++) {
+        cursor_reord_d = i;
+        forallElements(yield, pstart + i, d + 1);
+      }
+    }
+  }
+};
+
 /// Helper to convert string to lower case.
 static char *toLower(char *token) {
   for (char *c = token; *c; c++)
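A final usage note (editor's illustration, not from the patch): because `forallElements` reuses `cursor` as scratch storage and passes it to the callback by reference, a consumer that wants to retain coordinates must copy them:

    // Hypothetical: materializing all elements of a double tensor.
    std::vector<std::pair<std::vector<uint64_t>, double>> elems;
    enumerator.forallElements(
        [&elems](const std::vector<uint64_t> &ind, double val) {
          elems.emplace_back(ind, val); // copies `ind`; keeping &ind would dangle
        });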