[mlir][sparse] assert fail on mismatch between rank and annotations array

Rationale:
Providing the wrong number of sparse/dense annotations was silently
ignored or caused unrelated crashes. This minor change verifies that
the provided number matches the rank.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D97034
This commit is contained in:
Aart Bik 2021-02-18 22:01:39 -08:00
parent cd4051ac80
commit 2556d62282
1 changed files with 33 additions and 18 deletions

View File

@ -76,8 +76,8 @@ public:
}
/// Adds element as indices and value.
void add(const std::vector<uint64_t> &ind, double val) {
assert(sizes.size() == ind.size());
for (int64_t r = 0, rank = sizes.size(); r < rank; r++)
assert(getRank() == ind.size());
for (int64_t r = 0, rank = getRank(); r < rank; r++)
assert(ind[r] < sizes[r]); // within bounds
elements.emplace_back(Element(ind, val));
}
@ -85,6 +85,8 @@ public:
void sort() { std::sort(elements.begin(), elements.end(), lexOrder); }
/// Primitive one-time iteration.
const Element &next() { return elements[pos++]; }
/// Returns rank.
uint64_t getRank() const { return sizes.size(); }
/// Getter for sizes array.
const std::vector<uint64_t> &getSizes() const { return sizes; }
/// Getter for elements array.
@ -139,13 +141,13 @@ public:
/// Constructs sparse tensor storage scheme following the given
/// per-rank dimension dense/sparse annotations.
SparseTensorStorage(SparseTensor *tensor, bool *sparsity)
: sizes(tensor->getSizes()), pointers(sizes.size()),
indices(sizes.size()) {
: sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) {
// Provide hints on capacity.
// TODO: needs fine-tuning based on sparsity
values.reserve(tensor->getElements().size());
for (uint64_t d = 0, s = 1, rank = sizes.size(); d < rank; d++) {
s *= tensor->getSizes()[d];
uint64_t nnz = tensor->getElements().size();
values.reserve(nnz);
for (uint64_t d = 0, s = 1, rank = getRank(); d < rank; d++) {
s *= sizes[d];
if (sparsity[d]) {
pointers[d].reserve(s + 1);
indices[d].reserve(s);
@ -153,12 +155,16 @@ public:
}
}
// Then setup the tensor.
traverse(tensor, sparsity, 0, tensor->getElements().size(), 0);
traverse(tensor, sparsity, 0, nnz, 0);
}
virtual ~SparseTensorStorage() {}
uint64_t getRank() const { return sizes.size(); }
uint64_t getDimSize(uint64_t d) override { return sizes[d]; }
// Partially specialize these three methods based on template types.
void getPointers(std::vector<P> **out, uint64_t d) override {
*out = &pointers[d];
}
@ -176,7 +182,7 @@ private:
uint64_t d) {
const std::vector<Element> &elements = tensor->getElements();
// Once dimensions are exhausted, insert the numerical values.
if (d == sizes.size()) {
if (d == getRank()) {
values.push_back(lo < hi ? elements[lo].value : 0.0);
return;
}
@ -221,9 +227,10 @@ private:
/// Templated reader.
template <typename P, typename I, typename V>
void *newSparseTensor(char *filename, bool *sparsity) {
void *newSparseTensor(char *filename, bool *sparsity, uint64_t size) {
uint64_t idata[64];
SparseTensor *t = static_cast<SparseTensor *>(openTensorC(filename, idata));
assert(size == t->getRank()); // sparsity array must match rank
SparseTensorStorageBase *tensor =
new SparseTensorStorage<P, I, V>(t, sparsity);
delete t;
@ -481,21 +488,29 @@ void *newSparseTensor(char *filename, bool *abase, bool *adata, uint64_t aoff,
assert(astride == 1);
bool *sparsity = abase + aoff;
if (ptrTp == kU64 && indTp == kU64 && valTp == kF64)
return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity);
return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity,
asize);
if (ptrTp == kU64 && indTp == kU64 && valTp == kF32)
return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity);
return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity,
asize);
if (ptrTp == kU64 && indTp == kU32 && valTp == kF64)
return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity);
return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity,
asize);
if (ptrTp == kU64 && indTp == kU32 && valTp == kF32)
return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity);
return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity,
asize);
if (ptrTp == kU32 && indTp == kU64 && valTp == kF64)
return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity);
return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity,
asize);
if (ptrTp == kU32 && indTp == kU64 && valTp == kF32)
return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity);
return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity,
asize);
if (ptrTp == kU32 && indTp == kU32 && valTp == kF64)
return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity);
return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity,
asize);
if (ptrTp == kU32 && indTp == kU32 && valTp == kF32)
return newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity);
return newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity,
asize);
fputs("unsupported combination of types\n", stderr);
exit(1);
}