[mlir][sparse] Strengthening first arguments of fromCOO/toCOO

Better capturing of invariants

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D116700
This commit is contained in:
wren romano 2022-01-05 13:46:15 -08:00
parent c03fd1e61f
commit ceda1ae9a7
1 changed files with 9 additions and 9 deletions

View File

@ -269,9 +269,10 @@ public:
pointers[r].push_back(0);
// Then assign contents from coordinate scheme tensor if provided.
if (tensor) {
uint64_t nnz = tensor->getElements().size();
const std::vector<Element<V>> &elements = tensor->getElements();
uint64_t nnz = elements.size();
values.reserve(nnz);
fromCOO(tensor, 0, nnz, 0);
fromCOO(elements, 0, nnz, 0);
} else if (allDense) {
values.resize(sz, 0);
}
@ -367,7 +368,7 @@ public:
std::vector<uint64_t> reord(rank);
for (uint64_t r = 0; r < rank; r++)
reord[r] = perm[rev[r]];
toCOO(tensor, reord, 0, 0);
toCOO(*tensor, reord, 0, 0);
assert(tensor->getElements().size() == values.size());
return tensor;
}
@ -402,9 +403,8 @@ private:
/// Initializes sparse tensor storage scheme from a memory-resident sparse
/// tensor in coordinate scheme. This method prepares the pointers and
/// indices arrays under the given per-dimension dense/sparse annotations.
void fromCOO(SparseTensorCOO<V> *tensor, uint64_t lo, uint64_t hi,
uint64_t d) {
const std::vector<Element<V>> &elements = tensor->getElements();
void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
uint64_t hi, uint64_t d) {
// Once dimensions are exhausted, insert the numerical values.
assert(d <= getRank());
if (d == getRank()) {
@ -432,7 +432,7 @@ private:
endDim(d + 1);
full++;
}
fromCOO(tensor, lo, seg, d + 1);
fromCOO(elements, lo, seg, d + 1);
// And move on to next segment in interval.
lo = seg;
}
@ -449,12 +449,12 @@ private:
/// Stores the sparse tensor storage scheme into a memory-resident sparse
/// tensor in coordinate scheme.
void toCOO(SparseTensorCOO<V> *tensor, std::vector<uint64_t> &reord,
void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord,
uint64_t pos, uint64_t d) {
assert(d <= getRank());
if (d == getRank()) {
assert(pos < values.size());
tensor->add(idx, values[pos]);
tensor.add(idx, values[pos]);
} else if (isCompressedDim(d)) {
// Sparse dimension.
for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {