[mlir][sparse] Restyling macros in the runtime library

In addition to reducing code repetition, this also helps ensure that the various API functions follow the naming convention of mlir::sparse_tensor::primaryTypeFunctionSuffix (the repetitious code had accumulated typos that broke that convention).

Depends On D125428

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D125431
Author: wren romano
Date:   2022-05-11 16:32:54 -07:00
Parent: 7dbf2e7b57
Commit: 1313f5d307

1 changed file with 140 additions and 306 deletions


@@ -234,6 +234,29 @@ private:
unsigned iteratorPos = 0;
};
// See <https://en.wikipedia.org/wiki/X_Macro>
//
// `FOREVERY_SIMPLEX_V` only specifies the non-complex `V` types, because
// the ABI for complex types has compiler/architecture dependent complexities
// we need to work around. Namely, when a function takes a parameter of
// C/C++ type `complex32` (per se), then there is additional padding that
// causes it not to match the LLVM type `!llvm.struct<(f32, f32)>`. This
// only happens with the `complex32` type itself, not with pointers/arrays
// of complex values. So far `complex64` doesn't exhibit this ABI
// incompatibility, but we exclude it anyway just to be safe.
#define FOREVERY_SIMPLEX_V(DO) \
DO(F64, double) \
DO(F32, float) \
DO(I64, int64_t) \
DO(I32, int32_t) \
DO(I16, int16_t) \
DO(I8, int8_t)
#define FOREVERY_V(DO) \
FOREVERY_SIMPLEX_V(DO) \
DO(C64, complex64) \
DO(C32, complex32)
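// As a minimal illustration of the X-macro pattern (hypothetical, not part
// of this change): a caller passes a two-argument macro as `DO`, and
// `FOREVERY_V` stamps it out once per (suffix, type) pair.
//
//   #define COUNT_V(VNAME, V) +1
//   static_assert(0 FOREVERY_V(COUNT_V) == 8,
//                 "F64, F32, I64, I32, I16, I8, C64, C32");
//   #undef COUNT_V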
// Forward.
template <typename V>
class SparseTensorEnumeratorBase;
@@ -298,38 +321,13 @@ public:
}
/// Allocate a new enumerator.
virtual void newEnumerator(SparseTensorEnumeratorBase<double> **, uint64_t,
const uint64_t *) const {
fatal("enumf64");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<float> **, uint64_t,
const uint64_t *) const {
fatal("enumf32");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<int64_t> **, uint64_t,
const uint64_t *) const {
fatal("enumi64");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<int32_t> **, uint64_t,
const uint64_t *) const {
fatal("enumi32");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<int16_t> **, uint64_t,
const uint64_t *) const {
fatal("enumi16");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<int8_t> **, uint64_t,
const uint64_t *) const {
fatal("enumi8");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<complex64> **, uint64_t,
const uint64_t *) const {
fatal("enumc64");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<complex32> **, uint64_t,
const uint64_t *) const {
fatal("enumc32");
#define DECL_NEWENUMERATOR(VNAME, V) \
virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t, \
const uint64_t *) const { \
fatal("newEnumerator" #VNAME); \
}
FOREVERY_V(DECL_NEWENUMERATOR)
#undef DECL_NEWENUMERATOR
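// For reference, a single expansion of the X-macro above (here for F64) is
// identical to the handwritten overload it replaces, except that the fatal
// message is now derived mechanically from the suffix:
//
//   virtual void newEnumerator(SparseTensorEnumeratorBase<double> **, uint64_t,
//                              const uint64_t *) const {
//     fatal("newEnumeratorF64");
//   }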
/// Overhead storage.
virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
@@ -342,52 +340,24 @@ public:
virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
/// Primary storage.
virtual void getValues(std::vector<double> **) { fatal("valf64"); }
virtual void getValues(std::vector<float> **) { fatal("valf32"); }
virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
virtual void getValues(std::vector<complex64> **) { fatal("valc64"); }
virtual void getValues(std::vector<complex32> **) { fatal("valc32"); }
#define DECL_GETVALUES(VNAME, V) \
virtual void getValues(std::vector<V> **) { fatal("getValues" #VNAME); }
FOREVERY_V(DECL_GETVALUES)
#undef DECL_GETVALUES
/// Element-wise insertion in lexicographic index order.
virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); }
virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); }
virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
virtual void lexInsert(const uint64_t *, int16_t) { fatal("ins16"); }
virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }
virtual void lexInsert(const uint64_t *, complex64) { fatal("insc64"); }
virtual void lexInsert(const uint64_t *, complex32) { fatal("insc32"); }
#define DECL_LEXINSERT(VNAME, V) \
virtual void lexInsert(const uint64_t *, V) { fatal("lexInsert" #VNAME); }
FOREVERY_V(DECL_LEXINSERT)
#undef DECL_LEXINSERT
/// Expanded insertion.
virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
fatal("expf64");
}
virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
fatal("expf32");
}
virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
fatal("expi64");
}
virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
fatal("expi32");
}
virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
fatal("expi16");
}
virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
fatal("expi8");
}
virtual void expInsert(uint64_t *, complex64 *, bool *, uint64_t *,
uint64_t) {
fatal("expc64");
}
virtual void expInsert(uint64_t *, complex32 *, bool *, uint64_t *,
uint64_t) {
fatal("expc32");
#define DECL_EXPINSERT(VNAME, V) \
virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) { \
fatal("expInsert" #VNAME); \
}
FOREVERY_V(DECL_EXPINSERT)
#undef DECL_EXPINSERT
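// Design note (sketch, not part of the diff): the failing defaults above let
// a storage class templated on a single value type `V` override only the
// overloads that match its own `V`; any mismatched call through the base
// pointer reports a clear fatal error instead of silently misbehaving.
//
//   template <typename V>
//   class HypotheticalStorage : public SparseTensorStorageBase {
//   public:
//     void getValues(std::vector<V> **out) override { *out = &values; }
//     void lexInsert(const uint64_t *cursor, V val) override { /* ... */ }
//     // All other value-typed overloads fall through to the fatal defaults.
//   private:
//     std::vector<V> values;
//   };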
/// Finishes insertion.
virtual void endInsert() = 0;
@@ -1440,17 +1410,23 @@ extern "C" {
}
#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
// TODO(D125432): move `_mlir_ciface_newSparseTensor` closer to these
// macro definitions, but as a separate change so as not to muddy the diff.
#define IMPL_SPARSEVALUES(NAME, TYPE, LIB) \
void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) { \
/// Methods that provide direct access to values.
#define IMPL_SPARSEVALUES(VNAME, V) \
void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref, \
void *tensor) { \
assert(ref &&tensor); \
std::vector<TYPE> *v; \
static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v); \
std::vector<V> *v; \
static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v); \
ref->basePtr = ref->data = v->data(); \
ref->offset = 0; \
ref->sizes[0] = v->size(); \
ref->strides[0] = 1; \
}
FOREVERY_V(IMPL_SPARSEVALUES)
#undef IMPL_SPARSEVALUES
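// Usage sketch (hypothetical caller, assuming `tensor` holds doubles): the
// generated `_mlir_ciface_sparseValuesF64` fills a rank-1 memref descriptor
// that aliases the values array directly, without copying.
//
//   StridedMemRefType<double, 1> vals;
//   _mlir_ciface_sparseValuesF64(&vals, tensor);
//   for (int64_t i = 0; i < vals.sizes[0]; ++i)
//     use(vals.data[vals.offset + i * vals.strides[0]]);  // `use` is a stand-in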
#define IMPL_GETOVERHEAD(NAME, TYPE, LIB) \
void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor, \
@@ -1463,12 +1439,27 @@ extern "C" {
ref->sizes[0] = v->size(); \
ref->strides[0] = 1; \
}
/// Methods that provide direct access to pointers.
IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
#define IMPL_ADDELT(NAME, TYPE) \
void *_mlir_ciface_##NAME(void *tensor, TYPE value, \
/// Methods that provide direct access to indices.
IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
#undef IMPL_GETOVERHEAD
/// Helper to add value to coordinate scheme, one per value type.
#define IMPL_ADDELT(VNAME, V) \
void *_mlir_ciface_addElt##VNAME(void *coo, V value, \
StridedMemRefType<index_type, 1> *iref, \
StridedMemRefType<index_type, 1> *pref) { \
assert(tensor &&iref &&pref); \
assert(coo &&iref &&pref); \
assert(iref->strides[0] == 1 && pref->strides[0] == 1); \
assert(iref->sizes[0] == pref->sizes[0]); \
const index_type *indx = iref->data + iref->offset; \
@@ -1477,21 +1468,33 @@ extern "C" {
std::vector<index_type> indices(isize); \
for (uint64_t r = 0; r < isize; r++) \
indices[perm[r]] = indx[r]; \
static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value); \
return tensor; \
static_cast<SparseTensorCOO<V> *>(coo)->add(indices, value); \
return coo; \
}
FOREVERY_SIMPLEX_V(IMPL_ADDELT)
// `complex64` apparently doesn't encounter any ABI issues (yet).
IMPL_ADDELT(C64, complex64)
// TODO: cleaner way to avoid ABI padding problem?
IMPL_ADDELT(C32ABI, complex32)
void *_mlir_ciface_addEltC32(void *tensor, float r, float i,
StridedMemRefType<index_type, 1> *iref,
StridedMemRefType<index_type, 1> *pref) {
return _mlir_ciface_addEltC32ABI(tensor, complex32(r, i), iref, pref);
}
#undef IMPL_ADDELT
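// Usage sketch (hypothetical caller, assuming `coo` already points to a
// rank-2 SparseTensorCOO<double> and the permutation is the identity):
//
//   index_type idx[2] = {3, 7};
//   index_type perm[2] = {0, 1};
//   StridedMemRefType<index_type, 1> iref = {idx, idx, 0, {2}, {1}};
//   StridedMemRefType<index_type, 1> pref = {perm, perm, 0, {2}, {1}};
//   _mlir_ciface_addEltF64(coo, 42.0, &iref, &pref);  // element (3,7) = 42.0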
#define IMPL_GETNEXT(NAME, V) \
bool _mlir_ciface_##NAME(void *tensor, \
/// Helper to enumerate elements of coordinate scheme, one per value type.
#define IMPL_GETNEXT(VNAME, V) \
bool _mlir_ciface_getNext##VNAME(void *coo, \
StridedMemRefType<index_type, 1> *iref, \
StridedMemRefType<V, 0> *vref) { \
assert(tensor &&iref &&vref); \
assert(coo &&iref &&vref); \
assert(iref->strides[0] == 1); \
index_type *indx = iref->data + iref->offset; \
V *value = vref->data + vref->offset; \
const uint64_t isize = iref->sizes[0]; \
auto iter = static_cast<SparseTensorCOO<V> *>(tensor); \
const Element<V> *elem = iter->getNext(); \
const Element<V> *elem = \
static_cast<SparseTensorCOO<V> *>(coo)->getNext(); \
if (elem == nullptr) \
return false; \
for (uint64_t r = 0; r < isize; r++) \
@@ -1499,19 +1502,34 @@ extern "C" {
*value = elem->value; \
return true; \
}
FOREVERY_V(IMPL_GETNEXT)
#undef IMPL_GETNEXT
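// Usage sketch (hypothetical caller, continuing the rank-2 example above):
// drain the COO one element at a time until getNext signals exhaustion.
//
//   index_type idx[2];
//   double val;
//   StridedMemRefType<index_type, 1> iref = {idx, idx, 0, {2}, {1}};
//   StridedMemRefType<double, 0> vref = {&val, &val, 0};
//   while (_mlir_ciface_getNextF64(coo, &iref, &vref))
//     consume(idx[0], idx[1], val);  // `consume` is a stand-in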
#define IMPL_LEXINSERT(NAME, V) \
void _mlir_ciface_##NAME(void *tensor, \
StridedMemRefType<index_type, 1> *cref, V val) { \
/// Insert elements in lexicographical index order, one per value type.
#define IMPL_LEXINSERT(VNAME, V) \
void _mlir_ciface_lexInsert##VNAME( \
void *tensor, StridedMemRefType<index_type, 1> *cref, V val) { \
assert(tensor &&cref); \
assert(cref->strides[0] == 1); \
index_type *cursor = cref->data + cref->offset; \
assert(cursor); \
static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val); \
}
FOREVERY_SIMPLEX_V(IMPL_LEXINSERT)
// `complex64` apparently doesn't encounter any ABI issues (yet).
IMPL_LEXINSERT(C64, complex64)
// TODO: cleaner way to avoid ABI padding problem?
IMPL_LEXINSERT(C32ABI, complex32)
void _mlir_ciface_lexInsertC32(void *tensor,
StridedMemRefType<index_type, 1> *cref, float r,
float i) {
_mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
}
#undef IMPL_LEXINSERT
#define IMPL_EXPINSERT(NAME, V) \
void _mlir_ciface_##NAME( \
/// Insert using expansion, one per value type.
#define IMPL_EXPINSERT(VNAME, V) \
void _mlir_ciface_expInsert##VNAME( \
void *tensor, StridedMemRefType<index_type, 1> *cref, \
StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref, \
StridedMemRefType<index_type, 1> *aref, index_type count) { \
@@ -1528,6 +1546,8 @@ extern "C" {
static_cast<SparseTensorStorageBase *>(tensor)->expInsert( \
cursor, values, filled, added, count); \
}
FOREVERY_V(IMPL_EXPINSERT)
#undef IMPL_EXPINSERT
// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
// can safely rewrite kIndex to kU64. We make this assertion to guarantee
@@ -1658,122 +1678,16 @@ _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
fputs("unsupported combination of types\n", stderr);
exit(1);
}
/// Methods that provide direct access to pointers.
IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
/// Methods that provide direct access to indices.
IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
/// Methods that provide direct access to values.
IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
IMPL_SPARSEVALUES(sparseValuesC64, complex64, getValues)
IMPL_SPARSEVALUES(sparseValuesC32, complex32, getValues)
/// Helper to add value to coordinate scheme, one per value type.
IMPL_ADDELT(addEltF64, double)
IMPL_ADDELT(addEltF32, float)
IMPL_ADDELT(addEltI64, int64_t)
IMPL_ADDELT(addEltI32, int32_t)
IMPL_ADDELT(addEltI16, int16_t)
IMPL_ADDELT(addEltI8, int8_t)
IMPL_ADDELT(addEltC64, complex64)
IMPL_ADDELT(addEltC32ABI, complex32)
// Make prototype explicit to accept the !llvm.struct<(f32, f32)> without
// any padding (which seem to happen for complex32 when passed as scalar;
// all other cases, e.g. pointer to array, work as expected).
// TODO: cleaner way to avoid ABI padding problem?
void *_mlir_ciface_addEltC32(void *tensor, float r, float i,
StridedMemRefType<index_type, 1> *iref,
StridedMemRefType<index_type, 1> *pref) {
return _mlir_ciface_addEltC32ABI(tensor, complex32(r, i), iref, pref);
}
/// Helper to enumerate elements of coordinate scheme, one per value type.
IMPL_GETNEXT(getNextF64, double)
IMPL_GETNEXT(getNextF32, float)
IMPL_GETNEXT(getNextI64, int64_t)
IMPL_GETNEXT(getNextI32, int32_t)
IMPL_GETNEXT(getNextI16, int16_t)
IMPL_GETNEXT(getNextI8, int8_t)
IMPL_GETNEXT(getNextC64, complex64)
IMPL_GETNEXT(getNextC32, complex32)
/// Insert elements in lexicographical index order, one per value type.
IMPL_LEXINSERT(lexInsertF64, double)
IMPL_LEXINSERT(lexInsertF32, float)
IMPL_LEXINSERT(lexInsertI64, int64_t)
IMPL_LEXINSERT(lexInsertI32, int32_t)
IMPL_LEXINSERT(lexInsertI16, int16_t)
IMPL_LEXINSERT(lexInsertI8, int8_t)
IMPL_LEXINSERT(lexInsertC64, complex64)
IMPL_LEXINSERT(lexInsertC32ABI, complex32)
// Make prototype explicit to accept the !llvm.struct<(f32, f32)> without
// any padding (which seem to happen for complex32 when passed as scalar;
// all other cases, e.g. pointer to array, work as expected).
// TODO: cleaner way to avoid ABI padding problem?
void _mlir_ciface_lexInsertC32(void *tensor,
StridedMemRefType<index_type, 1> *cref, float r,
float i) {
_mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
}
/// Insert using expansion, one per value type.
IMPL_EXPINSERT(expInsertF64, double)
IMPL_EXPINSERT(expInsertF32, float)
IMPL_EXPINSERT(expInsertI64, int64_t)
IMPL_EXPINSERT(expInsertI32, int32_t)
IMPL_EXPINSERT(expInsertI16, int16_t)
IMPL_EXPINSERT(expInsertI8, int8_t)
IMPL_EXPINSERT(expInsertC64, complex64)
IMPL_EXPINSERT(expInsertC32, complex32)
#undef CASE
#undef IMPL_SPARSEVALUES
#undef IMPL_GETOVERHEAD
#undef IMPL_ADDELT
#undef IMPL_GETNEXT
#undef IMPL_LEXINSERT
#undef IMPL_EXPINSERT
#undef CASE_SECSAME
/// Output a sparse tensor, one per value type.
void outSparseTensorF64(void *tensor, void *dest, bool sort) {
return outSparseTensor<double>(tensor, dest, sort);
}
void outSparseTensorF32(void *tensor, void *dest, bool sort) {
return outSparseTensor<float>(tensor, dest, sort);
}
void outSparseTensorI64(void *tensor, void *dest, bool sort) {
return outSparseTensor<int64_t>(tensor, dest, sort);
}
void outSparseTensorI32(void *tensor, void *dest, bool sort) {
return outSparseTensor<int32_t>(tensor, dest, sort);
}
void outSparseTensorI16(void *tensor, void *dest, bool sort) {
return outSparseTensor<int16_t>(tensor, dest, sort);
}
void outSparseTensorI8(void *tensor, void *dest, bool sort) {
return outSparseTensor<int8_t>(tensor, dest, sort);
}
void outSparseTensorC64(void *tensor, void *dest, bool sort) {
return outSparseTensor<complex64>(tensor, dest, sort);
}
void outSparseTensorC32(void *tensor, void *dest, bool sort) {
return outSparseTensor<complex32>(tensor, dest, sort);
#define IMPL_OUTSPARSETENSOR(VNAME, V) \
void outSparseTensor##VNAME(void *coo, void *dest, bool sort) { \
return outSparseTensor<V>(coo, dest, sort); \
}
FOREVERY_V(IMPL_OUTSPARSETENSOR)
#undef IMPL_OUTSPARSETENSOR
//===----------------------------------------------------------------------===//
//
@@ -1817,14 +1731,7 @@ void delSparseTensor(void *tensor) {
void delSparseTensorCOO##VNAME(void *coo) { \
delete static_cast<SparseTensorCOO<V> *>(coo); \
}
IMPL_DELCOO(F64, double)
IMPL_DELCOO(F32, float)
IMPL_DELCOO(I64, int64_t)
IMPL_DELCOO(I32, int32_t)
IMPL_DELCOO(I16, int16_t)
IMPL_DELCOO(I8, int8_t)
IMPL_DELCOO(C64, complex64)
IMPL_DELCOO(C32, complex32)
FOREVERY_V(IMPL_DELCOO)
#undef IMPL_DELCOO
/// Initializes sparse tensor from a COO-flavored format expressed using C-style
@@ -1850,54 +1757,15 @@ IMPL_DELCOO(C32, complex32)
//
// TODO: generalize beyond 64-bit indices.
//
void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape,
double *values, uint64_t *indices,
uint64_t *perm, uint8_t *sparse) {
return toMLIRSparseTensor<double>(rank, nse, shape, values, indices, perm,
sparse);
}
void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
float *values, uint64_t *indices,
uint64_t *perm, uint8_t *sparse) {
return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm,
sparse);
}
void *convertToMLIRSparseTensorI64(uint64_t rank, uint64_t nse, uint64_t *shape,
int64_t *values, uint64_t *indices,
uint64_t *perm, uint8_t *sparse) {
return toMLIRSparseTensor<int64_t>(rank, nse, shape, values, indices, perm,
sparse);
}
void *convertToMLIRSparseTensorI32(uint64_t rank, uint64_t nse, uint64_t *shape,
int32_t *values, uint64_t *indices,
uint64_t *perm, uint8_t *sparse) {
return toMLIRSparseTensor<int32_t>(rank, nse, shape, values, indices, perm,
sparse);
}
void *convertToMLIRSparseTensorI16(uint64_t rank, uint64_t nse, uint64_t *shape,
int16_t *values, uint64_t *indices,
uint64_t *perm, uint8_t *sparse) {
return toMLIRSparseTensor<int16_t>(rank, nse, shape, values, indices, perm,
sparse);
}
void *convertToMLIRSparseTensorI8(uint64_t rank, uint64_t nse, uint64_t *shape,
int8_t *values, uint64_t *indices,
uint64_t *perm, uint8_t *sparse) {
return toMLIRSparseTensor<int8_t>(rank, nse, shape, values, indices, perm,
sparse);
}
void *convertToMLIRSparseTensorC64(uint64_t rank, uint64_t nse, uint64_t *shape,
complex64 *values, uint64_t *indices,
uint64_t *perm, uint8_t *sparse) {
return toMLIRSparseTensor<complex64>(rank, nse, shape, values, indices, perm,
sparse);
}
void *convertToMLIRSparseTensorC32(uint64_t rank, uint64_t nse, uint64_t *shape,
complex32 *values, uint64_t *indices,
uint64_t *perm, uint8_t *sparse) {
return toMLIRSparseTensor<complex32>(rank, nse, shape, values, indices, perm,
sparse);
#define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V) \
void *convertToMLIRSparseTensor##VNAME( \
uint64_t rank, uint64_t nse, uint64_t *shape, V *values, \
uint64_t *indices, uint64_t *perm, uint8_t *sparse) { \
return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm, \
sparse); \
}
FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
#undef IMPL_CONVERTTOMLIRSPARSETENSOR
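// Usage sketch (hypothetical caller): build a 3x4 tensor with two stored
// entries from plain C arrays.  The flattened nse-by-rank layout of `indices`
// and the per-dimension meaning of `sparse` are assumptions here, not
// guarantees of the API.
//
//   uint64_t shape[2] = {3, 4};
//   double values[2] = {1.5, 2.5};
//   uint64_t indices[4] = {0, 1,   // element (0,1) -> 1.5 (assumed layout)
//                          2, 3};  // element (2,3) -> 2.5
//   uint64_t perm[2] = {0, 1};
//   uint8_t sparse[2] = {0, 1};    // per-dimension sparsity flags (assumed)
//   void *tensor = convertToMLIRSparseTensorF64(
//       /*rank=*/2, /*nse=*/2, shape, values, indices, perm, sparse);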
/// Converts a sparse tensor to COO-flavored format expressed using C-style
/// data structures. The expected output parameters are pointers for these
@@ -1919,48 +1787,14 @@ void *convertToMLIRSparseTensorC32(uint64_t rank, uint64_t nse, uint64_t *shape,
// TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
// compressed
//
void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank,
uint64_t *pNse, uint64_t **pShape,
double **pValues, uint64_t **pIndices) {
fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
uint64_t *pNse, uint64_t **pShape,
float **pValues, uint64_t **pIndices) {
fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI64(void *tensor, uint64_t *pRank,
uint64_t *pNse, uint64_t **pShape,
int64_t **pValues, uint64_t **pIndices) {
fromMLIRSparseTensor<int64_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI32(void *tensor, uint64_t *pRank,
uint64_t *pNse, uint64_t **pShape,
int32_t **pValues, uint64_t **pIndices) {
fromMLIRSparseTensor<int32_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI16(void *tensor, uint64_t *pRank,
uint64_t *pNse, uint64_t **pShape,
int16_t **pValues, uint64_t **pIndices) {
fromMLIRSparseTensor<int16_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI8(void *tensor, uint64_t *pRank,
uint64_t *pNse, uint64_t **pShape,
int8_t **pValues, uint64_t **pIndices) {
fromMLIRSparseTensor<int8_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorC64(void *tensor, uint64_t *pRank,
uint64_t *pNse, uint64_t **pShape,
complex64 **pValues, uint64_t **pIndices) {
fromMLIRSparseTensor<complex64>(tensor, pRank, pNse, pShape, pValues,
pIndices);
}
void convertFromMLIRSparseTensorC32(void *tensor, uint64_t *pRank,
uint64_t *pNse, uint64_t **pShape,
complex32 **pValues, uint64_t **pIndices) {
fromMLIRSparseTensor<complex32>(tensor, pRank, pNse, pShape, pValues,
pIndices);
#define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V) \
void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank, \
uint64_t *pNse, uint64_t **pShape, \
V **pValues, uint64_t **pIndices) { \
fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices); \
}
FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
#undef IMPL_CONVERTFROMMLIRSPARSETENSOR
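// Usage sketch (hypothetical caller): read a tensor back into C-style arrays.
// The ownership of the returned buffers and the flattened nse-by-rank layout
// of `*pIndices` are assumptions here, not guarantees of the API.
//
//   uint64_t rank, nse, *shape, *indices;
//   double *values;
//   convertFromMLIRSparseTensorF64(tensor, &rank, &nse, &shape,
//                                  &values, &indices);
//   for (uint64_t i = 0; i < nse; ++i)
//     use(values[i]);  // `use` is a stand-in; indices[i * rank + d] (assumed)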
} // extern "C"