forked from OSchip/llvm-project
[mlir][sparse] assert fail on mismatch between rank and annotations array
Rationale: Providing the wrong number of sparse/dense annotations was silently ignored or caused unrelated crashes. This minor change verifies that the provided number matches the rank. Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D97034
This commit is contained in:
parent
cd4051ac80
commit
2556d62282
|
@ -76,8 +76,8 @@ public:
|
|||
}
|
||||
/// Adds element as indices and value.
|
||||
void add(const std::vector<uint64_t> &ind, double val) {
|
||||
assert(sizes.size() == ind.size());
|
||||
for (int64_t r = 0, rank = sizes.size(); r < rank; r++)
|
||||
assert(getRank() == ind.size());
|
||||
for (int64_t r = 0, rank = getRank(); r < rank; r++)
|
||||
assert(ind[r] < sizes[r]); // within bounds
|
||||
elements.emplace_back(Element(ind, val));
|
||||
}
|
||||
|
@ -85,6 +85,8 @@ public:
|
|||
void sort() { std::sort(elements.begin(), elements.end(), lexOrder); }
|
||||
/// Primitive one-time iteration.
|
||||
const Element &next() { return elements[pos++]; }
|
||||
/// Returns rank.
|
||||
uint64_t getRank() const { return sizes.size(); }
|
||||
/// Getter for sizes array.
|
||||
const std::vector<uint64_t> &getSizes() const { return sizes; }
|
||||
/// Getter for elements array.
|
||||
|
@ -139,13 +141,13 @@ public:
|
|||
/// Constructs sparse tensor storage scheme following the given
|
||||
/// per-rank dimension dense/sparse annotations.
|
||||
SparseTensorStorage(SparseTensor *tensor, bool *sparsity)
|
||||
: sizes(tensor->getSizes()), pointers(sizes.size()),
|
||||
indices(sizes.size()) {
|
||||
: sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) {
|
||||
// Provide hints on capacity.
|
||||
// TODO: needs fine-tuning based on sparsity
|
||||
values.reserve(tensor->getElements().size());
|
||||
for (uint64_t d = 0, s = 1, rank = sizes.size(); d < rank; d++) {
|
||||
s *= tensor->getSizes()[d];
|
||||
uint64_t nnz = tensor->getElements().size();
|
||||
values.reserve(nnz);
|
||||
for (uint64_t d = 0, s = 1, rank = getRank(); d < rank; d++) {
|
||||
s *= sizes[d];
|
||||
if (sparsity[d]) {
|
||||
pointers[d].reserve(s + 1);
|
||||
indices[d].reserve(s);
|
||||
|
@ -153,12 +155,16 @@ public:
|
|||
}
|
||||
}
|
||||
// Then setup the tensor.
|
||||
traverse(tensor, sparsity, 0, tensor->getElements().size(), 0);
|
||||
traverse(tensor, sparsity, 0, nnz, 0);
|
||||
}
|
||||
|
||||
virtual ~SparseTensorStorage() {}
|
||||
|
||||
uint64_t getRank() const { return sizes.size(); }
|
||||
|
||||
uint64_t getDimSize(uint64_t d) override { return sizes[d]; }
|
||||
|
||||
// Partially specialize these three methods based on template types.
|
||||
void getPointers(std::vector<P> **out, uint64_t d) override {
|
||||
*out = &pointers[d];
|
||||
}
|
||||
|
@ -176,7 +182,7 @@ private:
|
|||
uint64_t d) {
|
||||
const std::vector<Element> &elements = tensor->getElements();
|
||||
// Once dimensions are exhausted, insert the numerical values.
|
||||
if (d == sizes.size()) {
|
||||
if (d == getRank()) {
|
||||
values.push_back(lo < hi ? elements[lo].value : 0.0);
|
||||
return;
|
||||
}
|
||||
|
@ -221,9 +227,10 @@ private:
|
|||
|
||||
/// Templated reader.
|
||||
template <typename P, typename I, typename V>
|
||||
void *newSparseTensor(char *filename, bool *sparsity) {
|
||||
void *newSparseTensor(char *filename, bool *sparsity, uint64_t size) {
|
||||
uint64_t idata[64];
|
||||
SparseTensor *t = static_cast<SparseTensor *>(openTensorC(filename, idata));
|
||||
assert(size == t->getRank()); // sparsity array must match rank
|
||||
SparseTensorStorageBase *tensor =
|
||||
new SparseTensorStorage<P, I, V>(t, sparsity);
|
||||
delete t;
|
||||
|
@ -481,21 +488,29 @@ void *newSparseTensor(char *filename, bool *abase, bool *adata, uint64_t aoff,
|
|||
assert(astride == 1);
|
||||
bool *sparsity = abase + aoff;
|
||||
if (ptrTp == kU64 && indTp == kU64 && valTp == kF64)
|
||||
return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity);
|
||||
return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity,
|
||||
asize);
|
||||
if (ptrTp == kU64 && indTp == kU64 && valTp == kF32)
|
||||
return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity);
|
||||
return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity,
|
||||
asize);
|
||||
if (ptrTp == kU64 && indTp == kU32 && valTp == kF64)
|
||||
return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity);
|
||||
return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity,
|
||||
asize);
|
||||
if (ptrTp == kU64 && indTp == kU32 && valTp == kF32)
|
||||
return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity);
|
||||
return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity,
|
||||
asize);
|
||||
if (ptrTp == kU32 && indTp == kU64 && valTp == kF64)
|
||||
return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity);
|
||||
return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity,
|
||||
asize);
|
||||
if (ptrTp == kU32 && indTp == kU64 && valTp == kF32)
|
||||
return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity);
|
||||
return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity,
|
||||
asize);
|
||||
if (ptrTp == kU32 && indTp == kU32 && valTp == kF64)
|
||||
return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity);
|
||||
return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity,
|
||||
asize);
|
||||
if (ptrTp == kU32 && indTp == kU32 && valTp == kF32)
|
||||
return newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity);
|
||||
return newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity,
|
||||
asize);
|
||||
fputs("unsupported combination of types\n", stderr);
|
||||
exit(1);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue