forked from OSchip/llvm-project
[mlir][sparse] Restyling macros in the runtime library
In addition to reducing code repetition, this also helps ensure that the various API functions follow the naming convention of mlir::sparse_tensor::primaryTypeFunctionSuffix (e.g., due to typos in the repetitious code). Depends On D125428 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D125431
This commit is contained in:
parent
7dbf2e7b57
commit
1313f5d307
|
@ -234,6 +234,29 @@ private:
|
|||
unsigned iteratorPos = 0;
|
||||
};
|
||||
|
||||
// See <https://en.wikipedia.org/wiki/X_Macro>
|
||||
//
|
||||
// `FOREVERY_SIMPLEX_V` only specifies the non-complex `V` types, because
|
||||
// the ABI for complex types has compiler/architecture dependent complexities
|
||||
// we need to work around. Namely, when a function takes a parameter of
|
||||
// C/C++ type `complex32` (per se), then there is additional padding that
|
||||
// causes it not to match the LLVM type `!llvm.struct<(f32, f32)>`. This
|
||||
// only happens with the `complex32` type itself, not with pointers/arrays
|
||||
// of complex values. So far `complex64` doesn't exhibit this ABI
|
||||
// incompatibility, but we exclude it anyways just to be safe.
|
||||
#define FOREVERY_SIMPLEX_V(DO) \
|
||||
DO(F64, double) \
|
||||
DO(F32, float) \
|
||||
DO(I64, int64_t) \
|
||||
DO(I32, int32_t) \
|
||||
DO(I16, int16_t) \
|
||||
DO(I8, int8_t)
|
||||
|
||||
#define FOREVERY_V(DO) \
|
||||
FOREVERY_SIMPLEX_V(DO) \
|
||||
DO(C64, complex64) \
|
||||
DO(C32, complex32)
|
||||
|
||||
// Forward.
|
||||
template <typename V>
|
||||
class SparseTensorEnumeratorBase;
|
||||
|
@ -298,38 +321,13 @@ public:
|
|||
}
|
||||
|
||||
/// Allocate a new enumerator.
|
||||
virtual void newEnumerator(SparseTensorEnumeratorBase<double> **, uint64_t,
|
||||
const uint64_t *) const {
|
||||
fatal("enumf64");
|
||||
}
|
||||
virtual void newEnumerator(SparseTensorEnumeratorBase<float> **, uint64_t,
|
||||
const uint64_t *) const {
|
||||
fatal("enumf32");
|
||||
}
|
||||
virtual void newEnumerator(SparseTensorEnumeratorBase<int64_t> **, uint64_t,
|
||||
const uint64_t *) const {
|
||||
fatal("enumi64");
|
||||
}
|
||||
virtual void newEnumerator(SparseTensorEnumeratorBase<int32_t> **, uint64_t,
|
||||
const uint64_t *) const {
|
||||
fatal("enumi32");
|
||||
}
|
||||
virtual void newEnumerator(SparseTensorEnumeratorBase<int16_t> **, uint64_t,
|
||||
const uint64_t *) const {
|
||||
fatal("enumi16");
|
||||
}
|
||||
virtual void newEnumerator(SparseTensorEnumeratorBase<int8_t> **, uint64_t,
|
||||
const uint64_t *) const {
|
||||
fatal("enumi8");
|
||||
}
|
||||
virtual void newEnumerator(SparseTensorEnumeratorBase<complex64> **, uint64_t,
|
||||
const uint64_t *) const {
|
||||
fatal("enumc64");
|
||||
}
|
||||
virtual void newEnumerator(SparseTensorEnumeratorBase<complex32> **, uint64_t,
|
||||
const uint64_t *) const {
|
||||
fatal("enumc32");
|
||||
#define DECL_NEWENUMERATOR(VNAME, V) \
|
||||
virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t, \
|
||||
const uint64_t *) const { \
|
||||
fatal("newEnumerator" #VNAME); \
|
||||
}
|
||||
FOREVERY_V(DECL_NEWENUMERATOR)
|
||||
#undef DECL_NEWENUMERATOR
|
||||
|
||||
/// Overhead storage.
|
||||
virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
|
||||
|
@ -342,52 +340,24 @@ public:
|
|||
virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
|
||||
|
||||
/// Primary storage.
|
||||
virtual void getValues(std::vector<double> **) { fatal("valf64"); }
|
||||
virtual void getValues(std::vector<float> **) { fatal("valf32"); }
|
||||
virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
|
||||
virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
|
||||
virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
|
||||
virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
|
||||
virtual void getValues(std::vector<complex64> **) { fatal("valc64"); }
|
||||
virtual void getValues(std::vector<complex32> **) { fatal("valc32"); }
|
||||
#define DECL_GETVALUES(VNAME, V) \
|
||||
virtual void getValues(std::vector<V> **) { fatal("getValues" #VNAME); }
|
||||
FOREVERY_V(DECL_GETVALUES)
|
||||
#undef DECL_GETVALUES
|
||||
|
||||
/// Element-wise insertion in lexicographic index order.
|
||||
virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
|
||||
virtual void lexInsert(const uint64_t *, float) { fatal("insf32"); }
|
||||
virtual void lexInsert(const uint64_t *, int64_t) { fatal("insi64"); }
|
||||
virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
|
||||
virtual void lexInsert(const uint64_t *, int16_t) { fatal("ins16"); }
|
||||
virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }
|
||||
virtual void lexInsert(const uint64_t *, complex64) { fatal("insc64"); }
|
||||
virtual void lexInsert(const uint64_t *, complex32) { fatal("insc32"); }
|
||||
#define DECL_LEXINSERT(VNAME, V) \
|
||||
virtual void lexInsert(const uint64_t *, V) { fatal("lexInsert" #VNAME); }
|
||||
FOREVERY_V(DECL_LEXINSERT)
|
||||
#undef DECL_LEXINSERT
|
||||
|
||||
/// Expanded insertion.
|
||||
virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
|
||||
fatal("expf64");
|
||||
}
|
||||
virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
|
||||
fatal("expf32");
|
||||
}
|
||||
virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
|
||||
fatal("expi64");
|
||||
}
|
||||
virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
|
||||
fatal("expi32");
|
||||
}
|
||||
virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
|
||||
fatal("expi16");
|
||||
}
|
||||
virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
|
||||
fatal("expi8");
|
||||
}
|
||||
virtual void expInsert(uint64_t *, complex64 *, bool *, uint64_t *,
|
||||
uint64_t) {
|
||||
fatal("expc64");
|
||||
}
|
||||
virtual void expInsert(uint64_t *, complex32 *, bool *, uint64_t *,
|
||||
uint64_t) {
|
||||
fatal("expc32");
|
||||
#define DECL_EXPINSERT(VNAME, V) \
|
||||
virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) { \
|
||||
fatal("expInsert" #VNAME); \
|
||||
}
|
||||
FOREVERY_V(DECL_EXPINSERT)
|
||||
#undef DECL_EXPINSERT
|
||||
|
||||
/// Finishes insertion.
|
||||
virtual void endInsert() = 0;
|
||||
|
@ -1440,17 +1410,23 @@ extern "C" {
|
|||
}
|
||||
|
||||
#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
|
||||
// TODO(D125432): move `_mlir_ciface_newSparseTensor` closer to these
|
||||
// macro definitions, but as a separate change so as not to muddy the diff.
|
||||
|
||||
#define IMPL_SPARSEVALUES(NAME, TYPE, LIB) \
|
||||
void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) { \
|
||||
/// Methods that provide direct access to values.
|
||||
#define IMPL_SPARSEVALUES(VNAME, V) \
|
||||
void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref, \
|
||||
void *tensor) { \
|
||||
assert(ref &&tensor); \
|
||||
std::vector<TYPE> *v; \
|
||||
static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v); \
|
||||
std::vector<V> *v; \
|
||||
static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v); \
|
||||
ref->basePtr = ref->data = v->data(); \
|
||||
ref->offset = 0; \
|
||||
ref->sizes[0] = v->size(); \
|
||||
ref->strides[0] = 1; \
|
||||
}
|
||||
FOREVERY_V(IMPL_SPARSEVALUES)
|
||||
#undef IMPL_SPARSEVALUES
|
||||
|
||||
#define IMPL_GETOVERHEAD(NAME, TYPE, LIB) \
|
||||
void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor, \
|
||||
|
@ -1463,12 +1439,27 @@ extern "C" {
|
|||
ref->sizes[0] = v->size(); \
|
||||
ref->strides[0] = 1; \
|
||||
}
|
||||
/// Methods that provide direct access to pointers.
|
||||
IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
|
||||
IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
|
||||
IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
|
||||
IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
|
||||
IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
|
||||
|
||||
#define IMPL_ADDELT(NAME, TYPE) \
|
||||
void *_mlir_ciface_##NAME(void *tensor, TYPE value, \
|
||||
StridedMemRefType<index_type, 1> *iref, \
|
||||
StridedMemRefType<index_type, 1> *pref) { \
|
||||
assert(tensor &&iref &&pref); \
|
||||
/// Methods that provide direct access to indices.
|
||||
IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
|
||||
IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
|
||||
IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
|
||||
IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
|
||||
IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
|
||||
#undef IMPL_GETOVERHEAD
|
||||
|
||||
/// Helper to add value to coordinate scheme, one per value type.
|
||||
#define IMPL_ADDELT(VNAME, V) \
|
||||
void *_mlir_ciface_addElt##VNAME(void *coo, V value, \
|
||||
StridedMemRefType<index_type, 1> *iref, \
|
||||
StridedMemRefType<index_type, 1> *pref) { \
|
||||
assert(coo &&iref &&pref); \
|
||||
assert(iref->strides[0] == 1 && pref->strides[0] == 1); \
|
||||
assert(iref->sizes[0] == pref->sizes[0]); \
|
||||
const index_type *indx = iref->data + iref->offset; \
|
||||
|
@ -1477,21 +1468,33 @@ extern "C" {
|
|||
std::vector<index_type> indices(isize); \
|
||||
for (uint64_t r = 0; r < isize; r++) \
|
||||
indices[perm[r]] = indx[r]; \
|
||||
static_cast<SparseTensorCOO<TYPE> *>(tensor)->add(indices, value); \
|
||||
return tensor; \
|
||||
static_cast<SparseTensorCOO<V> *>(coo)->add(indices, value); \
|
||||
return coo; \
|
||||
}
|
||||
FOREVERY_SIMPLEX_V(IMPL_ADDELT)
|
||||
// `complex64` apparently doesn't encounter any ABI issues (yet).
|
||||
IMPL_ADDELT(C64, complex64)
|
||||
// TODO: cleaner way to avoid ABI padding problem?
|
||||
IMPL_ADDELT(C32ABI, complex32)
|
||||
void *_mlir_ciface_addEltC32(void *tensor, float r, float i,
|
||||
StridedMemRefType<index_type, 1> *iref,
|
||||
StridedMemRefType<index_type, 1> *pref) {
|
||||
return _mlir_ciface_addEltC32ABI(tensor, complex32(r, i), iref, pref);
|
||||
}
|
||||
#undef IMPL_ADDELT
|
||||
|
||||
#define IMPL_GETNEXT(NAME, V) \
|
||||
bool _mlir_ciface_##NAME(void *tensor, \
|
||||
StridedMemRefType<index_type, 1> *iref, \
|
||||
StridedMemRefType<V, 0> *vref) { \
|
||||
assert(tensor &&iref &&vref); \
|
||||
/// Helper to enumerate elements of coordinate scheme, one per value type.
|
||||
#define IMPL_GETNEXT(VNAME, V) \
|
||||
bool _mlir_ciface_getNext##VNAME(void *coo, \
|
||||
StridedMemRefType<index_type, 1> *iref, \
|
||||
StridedMemRefType<V, 0> *vref) { \
|
||||
assert(coo &&iref &&vref); \
|
||||
assert(iref->strides[0] == 1); \
|
||||
index_type *indx = iref->data + iref->offset; \
|
||||
V *value = vref->data + vref->offset; \
|
||||
const uint64_t isize = iref->sizes[0]; \
|
||||
auto iter = static_cast<SparseTensorCOO<V> *>(tensor); \
|
||||
const Element<V> *elem = iter->getNext(); \
|
||||
const Element<V> *elem = \
|
||||
static_cast<SparseTensorCOO<V> *>(coo)->getNext(); \
|
||||
if (elem == nullptr) \
|
||||
return false; \
|
||||
for (uint64_t r = 0; r < isize; r++) \
|
||||
|
@ -1499,19 +1502,34 @@ extern "C" {
|
|||
*value = elem->value; \
|
||||
return true; \
|
||||
}
|
||||
FOREVERY_V(IMPL_GETNEXT)
|
||||
#undef IMPL_GETNEXT
|
||||
|
||||
#define IMPL_LEXINSERT(NAME, V) \
|
||||
void _mlir_ciface_##NAME(void *tensor, \
|
||||
StridedMemRefType<index_type, 1> *cref, V val) { \
|
||||
/// Insert elements in lexicographical index order, one per value type.
|
||||
#define IMPL_LEXINSERT(VNAME, V) \
|
||||
void _mlir_ciface_lexInsert##VNAME( \
|
||||
void *tensor, StridedMemRefType<index_type, 1> *cref, V val) { \
|
||||
assert(tensor &&cref); \
|
||||
assert(cref->strides[0] == 1); \
|
||||
index_type *cursor = cref->data + cref->offset; \
|
||||
assert(cursor); \
|
||||
static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val); \
|
||||
}
|
||||
FOREVERY_SIMPLEX_V(IMPL_LEXINSERT)
|
||||
// `complex64` apparently doesn't encounter any ABI issues (yet).
|
||||
IMPL_LEXINSERT(C64, complex64)
|
||||
// TODO: cleaner way to avoid ABI padding problem?
|
||||
IMPL_LEXINSERT(C32ABI, complex32)
|
||||
void _mlir_ciface_lexInsertC32(void *tensor,
|
||||
StridedMemRefType<index_type, 1> *cref, float r,
|
||||
float i) {
|
||||
_mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
|
||||
}
|
||||
#undef IMPL_LEXINSERT
|
||||
|
||||
#define IMPL_EXPINSERT(NAME, V) \
|
||||
void _mlir_ciface_##NAME( \
|
||||
/// Insert using expansion, one per value type.
|
||||
#define IMPL_EXPINSERT(VNAME, V) \
|
||||
void _mlir_ciface_expInsert##VNAME( \
|
||||
void *tensor, StridedMemRefType<index_type, 1> *cref, \
|
||||
StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref, \
|
||||
StridedMemRefType<index_type, 1> *aref, index_type count) { \
|
||||
|
@ -1528,6 +1546,8 @@ extern "C" {
|
|||
static_cast<SparseTensorStorageBase *>(tensor)->expInsert( \
|
||||
cursor, values, filled, added, count); \
|
||||
}
|
||||
FOREVERY_V(IMPL_EXPINSERT)
|
||||
#undef IMPL_EXPINSERT
|
||||
|
||||
// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
|
||||
// can safely rewrite kIndex to kU64. We make this assertion to guarantee
|
||||
|
@ -1658,122 +1678,16 @@ _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
|
|||
fputs("unsupported combination of types\n", stderr);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
/// Methods that provide direct access to pointers.
|
||||
IMPL_GETOVERHEAD(sparsePointers, index_type, getPointers)
|
||||
IMPL_GETOVERHEAD(sparsePointers64, uint64_t, getPointers)
|
||||
IMPL_GETOVERHEAD(sparsePointers32, uint32_t, getPointers)
|
||||
IMPL_GETOVERHEAD(sparsePointers16, uint16_t, getPointers)
|
||||
IMPL_GETOVERHEAD(sparsePointers8, uint8_t, getPointers)
|
||||
|
||||
/// Methods that provide direct access to indices.
|
||||
IMPL_GETOVERHEAD(sparseIndices, index_type, getIndices)
|
||||
IMPL_GETOVERHEAD(sparseIndices64, uint64_t, getIndices)
|
||||
IMPL_GETOVERHEAD(sparseIndices32, uint32_t, getIndices)
|
||||
IMPL_GETOVERHEAD(sparseIndices16, uint16_t, getIndices)
|
||||
IMPL_GETOVERHEAD(sparseIndices8, uint8_t, getIndices)
|
||||
|
||||
/// Methods that provide direct access to values.
|
||||
IMPL_SPARSEVALUES(sparseValuesF64, double, getValues)
|
||||
IMPL_SPARSEVALUES(sparseValuesF32, float, getValues)
|
||||
IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
|
||||
IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
|
||||
IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
|
||||
IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
|
||||
IMPL_SPARSEVALUES(sparseValuesC64, complex64, getValues)
|
||||
IMPL_SPARSEVALUES(sparseValuesC32, complex32, getValues)
|
||||
|
||||
/// Helper to add value to coordinate scheme, one per value type.
|
||||
IMPL_ADDELT(addEltF64, double)
|
||||
IMPL_ADDELT(addEltF32, float)
|
||||
IMPL_ADDELT(addEltI64, int64_t)
|
||||
IMPL_ADDELT(addEltI32, int32_t)
|
||||
IMPL_ADDELT(addEltI16, int16_t)
|
||||
IMPL_ADDELT(addEltI8, int8_t)
|
||||
IMPL_ADDELT(addEltC64, complex64)
|
||||
IMPL_ADDELT(addEltC32ABI, complex32)
|
||||
// Make prototype explicit to accept the !llvm.struct<(f32, f32)> without
|
||||
// any padding (which seem to happen for complex32 when passed as scalar;
|
||||
// all other cases, e.g. pointer to array, work as expected).
|
||||
// TODO: cleaner way to avoid ABI padding problem?
|
||||
void *_mlir_ciface_addEltC32(void *tensor, float r, float i,
|
||||
StridedMemRefType<index_type, 1> *iref,
|
||||
StridedMemRefType<index_type, 1> *pref) {
|
||||
return _mlir_ciface_addEltC32ABI(tensor, complex32(r, i), iref, pref);
|
||||
}
|
||||
|
||||
/// Helper to enumerate elements of coordinate scheme, one per value type.
|
||||
IMPL_GETNEXT(getNextF64, double)
|
||||
IMPL_GETNEXT(getNextF32, float)
|
||||
IMPL_GETNEXT(getNextI64, int64_t)
|
||||
IMPL_GETNEXT(getNextI32, int32_t)
|
||||
IMPL_GETNEXT(getNextI16, int16_t)
|
||||
IMPL_GETNEXT(getNextI8, int8_t)
|
||||
IMPL_GETNEXT(getNextC64, complex64)
|
||||
IMPL_GETNEXT(getNextC32, complex32)
|
||||
|
||||
/// Insert elements in lexicographical index order, one per value type.
|
||||
IMPL_LEXINSERT(lexInsertF64, double)
|
||||
IMPL_LEXINSERT(lexInsertF32, float)
|
||||
IMPL_LEXINSERT(lexInsertI64, int64_t)
|
||||
IMPL_LEXINSERT(lexInsertI32, int32_t)
|
||||
IMPL_LEXINSERT(lexInsertI16, int16_t)
|
||||
IMPL_LEXINSERT(lexInsertI8, int8_t)
|
||||
IMPL_LEXINSERT(lexInsertC64, complex64)
|
||||
IMPL_LEXINSERT(lexInsertC32ABI, complex32)
|
||||
// Make prototype explicit to accept the !llvm.struct<(f32, f32)> without
|
||||
// any padding (which seem to happen for complex32 when passed as scalar;
|
||||
// all other cases, e.g. pointer to array, work as expected).
|
||||
// TODO: cleaner way to avoid ABI padding problem?
|
||||
void _mlir_ciface_lexInsertC32(void *tensor,
|
||||
StridedMemRefType<index_type, 1> *cref, float r,
|
||||
float i) {
|
||||
_mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
|
||||
}
|
||||
|
||||
/// Insert using expansion, one per value type.
|
||||
IMPL_EXPINSERT(expInsertF64, double)
|
||||
IMPL_EXPINSERT(expInsertF32, float)
|
||||
IMPL_EXPINSERT(expInsertI64, int64_t)
|
||||
IMPL_EXPINSERT(expInsertI32, int32_t)
|
||||
IMPL_EXPINSERT(expInsertI16, int16_t)
|
||||
IMPL_EXPINSERT(expInsertI8, int8_t)
|
||||
IMPL_EXPINSERT(expInsertC64, complex64)
|
||||
IMPL_EXPINSERT(expInsertC32, complex32)
|
||||
|
||||
#undef CASE
|
||||
#undef IMPL_SPARSEVALUES
|
||||
#undef IMPL_GETOVERHEAD
|
||||
#undef IMPL_ADDELT
|
||||
#undef IMPL_GETNEXT
|
||||
#undef IMPL_LEXINSERT
|
||||
#undef IMPL_EXPINSERT
|
||||
#undef CASE_SECSAME
|
||||
|
||||
/// Output a sparse tensor, one per value type.
|
||||
void outSparseTensorF64(void *tensor, void *dest, bool sort) {
|
||||
return outSparseTensor<double>(tensor, dest, sort);
|
||||
}
|
||||
void outSparseTensorF32(void *tensor, void *dest, bool sort) {
|
||||
return outSparseTensor<float>(tensor, dest, sort);
|
||||
}
|
||||
void outSparseTensorI64(void *tensor, void *dest, bool sort) {
|
||||
return outSparseTensor<int64_t>(tensor, dest, sort);
|
||||
}
|
||||
void outSparseTensorI32(void *tensor, void *dest, bool sort) {
|
||||
return outSparseTensor<int32_t>(tensor, dest, sort);
|
||||
}
|
||||
void outSparseTensorI16(void *tensor, void *dest, bool sort) {
|
||||
return outSparseTensor<int16_t>(tensor, dest, sort);
|
||||
}
|
||||
void outSparseTensorI8(void *tensor, void *dest, bool sort) {
|
||||
return outSparseTensor<int8_t>(tensor, dest, sort);
|
||||
}
|
||||
void outSparseTensorC64(void *tensor, void *dest, bool sort) {
|
||||
return outSparseTensor<complex64>(tensor, dest, sort);
|
||||
}
|
||||
void outSparseTensorC32(void *tensor, void *dest, bool sort) {
|
||||
return outSparseTensor<complex32>(tensor, dest, sort);
|
||||
}
|
||||
#define IMPL_OUTSPARSETENSOR(VNAME, V) \
|
||||
void outSparseTensor##VNAME(void *coo, void *dest, bool sort) { \
|
||||
return outSparseTensor<V>(coo, dest, sort); \
|
||||
}
|
||||
FOREVERY_V(IMPL_OUTSPARSETENSOR)
|
||||
#undef IMPL_OUTSPARSETENSOR
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
|
@ -1817,14 +1731,7 @@ void delSparseTensor(void *tensor) {
|
|||
void delSparseTensorCOO##VNAME(void *coo) { \
|
||||
delete static_cast<SparseTensorCOO<V> *>(coo); \
|
||||
}
|
||||
IMPL_DELCOO(F64, double)
|
||||
IMPL_DELCOO(F32, float)
|
||||
IMPL_DELCOO(I64, int64_t)
|
||||
IMPL_DELCOO(I32, int32_t)
|
||||
IMPL_DELCOO(I16, int16_t)
|
||||
IMPL_DELCOO(I8, int8_t)
|
||||
IMPL_DELCOO(C64, complex64)
|
||||
IMPL_DELCOO(C32, complex32)
|
||||
FOREVERY_V(IMPL_DELCOO)
|
||||
#undef IMPL_DELCOO
|
||||
|
||||
/// Initializes sparse tensor from a COO-flavored format expressed using C-style
|
||||
|
@ -1850,54 +1757,15 @@ IMPL_DELCOO(C32, complex32)
|
|||
//
|
||||
// TODO: generalize beyond 64-bit indices.
|
||||
//
|
||||
void *convertToMLIRSparseTensorF64(uint64_t rank, uint64_t nse, uint64_t *shape,
|
||||
double *values, uint64_t *indices,
|
||||
uint64_t *perm, uint8_t *sparse) {
|
||||
return toMLIRSparseTensor<double>(rank, nse, shape, values, indices, perm,
|
||||
sparse);
|
||||
}
|
||||
void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
|
||||
float *values, uint64_t *indices,
|
||||
uint64_t *perm, uint8_t *sparse) {
|
||||
return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm,
|
||||
sparse);
|
||||
}
|
||||
void *convertToMLIRSparseTensorI64(uint64_t rank, uint64_t nse, uint64_t *shape,
|
||||
int64_t *values, uint64_t *indices,
|
||||
uint64_t *perm, uint8_t *sparse) {
|
||||
return toMLIRSparseTensor<int64_t>(rank, nse, shape, values, indices, perm,
|
||||
sparse);
|
||||
}
|
||||
void *convertToMLIRSparseTensorI32(uint64_t rank, uint64_t nse, uint64_t *shape,
|
||||
int32_t *values, uint64_t *indices,
|
||||
uint64_t *perm, uint8_t *sparse) {
|
||||
return toMLIRSparseTensor<int32_t>(rank, nse, shape, values, indices, perm,
|
||||
sparse);
|
||||
}
|
||||
void *convertToMLIRSparseTensorI16(uint64_t rank, uint64_t nse, uint64_t *shape,
|
||||
int16_t *values, uint64_t *indices,
|
||||
uint64_t *perm, uint8_t *sparse) {
|
||||
return toMLIRSparseTensor<int16_t>(rank, nse, shape, values, indices, perm,
|
||||
sparse);
|
||||
}
|
||||
void *convertToMLIRSparseTensorI8(uint64_t rank, uint64_t nse, uint64_t *shape,
|
||||
int8_t *values, uint64_t *indices,
|
||||
uint64_t *perm, uint8_t *sparse) {
|
||||
return toMLIRSparseTensor<int8_t>(rank, nse, shape, values, indices, perm,
|
||||
sparse);
|
||||
}
|
||||
void *convertToMLIRSparseTensorC64(uint64_t rank, uint64_t nse, uint64_t *shape,
|
||||
complex64 *values, uint64_t *indices,
|
||||
uint64_t *perm, uint8_t *sparse) {
|
||||
return toMLIRSparseTensor<complex64>(rank, nse, shape, values, indices, perm,
|
||||
sparse);
|
||||
}
|
||||
void *convertToMLIRSparseTensorC32(uint64_t rank, uint64_t nse, uint64_t *shape,
|
||||
complex32 *values, uint64_t *indices,
|
||||
uint64_t *perm, uint8_t *sparse) {
|
||||
return toMLIRSparseTensor<complex32>(rank, nse, shape, values, indices, perm,
|
||||
sparse);
|
||||
}
|
||||
#define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V) \
|
||||
void *convertToMLIRSparseTensor##VNAME( \
|
||||
uint64_t rank, uint64_t nse, uint64_t *shape, V *values, \
|
||||
uint64_t *indices, uint64_t *perm, uint8_t *sparse) { \
|
||||
return toMLIRSparseTensor<V>(rank, nse, shape, values, indices, perm, \
|
||||
sparse); \
|
||||
}
|
||||
FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
|
||||
#undef IMPL_CONVERTTOMLIRSPARSETENSOR
|
||||
|
||||
/// Converts a sparse tensor to COO-flavored format expressed using C-style
|
||||
/// data structures. The expected output parameters are pointers for these
|
||||
|
@ -1919,48 +1787,14 @@ void *convertToMLIRSparseTensorC32(uint64_t rank, uint64_t nse, uint64_t *shape,
|
|||
// TODO: generalize beyond 64-bit indices, no dim ordering, all dimensions
|
||||
// compressed
|
||||
//
|
||||
void convertFromMLIRSparseTensorF64(void *tensor, uint64_t *pRank,
|
||||
uint64_t *pNse, uint64_t **pShape,
|
||||
double **pValues, uint64_t **pIndices) {
|
||||
fromMLIRSparseTensor<double>(tensor, pRank, pNse, pShape, pValues, pIndices);
|
||||
}
|
||||
void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
|
||||
uint64_t *pNse, uint64_t **pShape,
|
||||
float **pValues, uint64_t **pIndices) {
|
||||
fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
|
||||
}
|
||||
void convertFromMLIRSparseTensorI64(void *tensor, uint64_t *pRank,
|
||||
uint64_t *pNse, uint64_t **pShape,
|
||||
int64_t **pValues, uint64_t **pIndices) {
|
||||
fromMLIRSparseTensor<int64_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
|
||||
}
|
||||
void convertFromMLIRSparseTensorI32(void *tensor, uint64_t *pRank,
|
||||
uint64_t *pNse, uint64_t **pShape,
|
||||
int32_t **pValues, uint64_t **pIndices) {
|
||||
fromMLIRSparseTensor<int32_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
|
||||
}
|
||||
void convertFromMLIRSparseTensorI16(void *tensor, uint64_t *pRank,
|
||||
uint64_t *pNse, uint64_t **pShape,
|
||||
int16_t **pValues, uint64_t **pIndices) {
|
||||
fromMLIRSparseTensor<int16_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
|
||||
}
|
||||
void convertFromMLIRSparseTensorI8(void *tensor, uint64_t *pRank,
|
||||
uint64_t *pNse, uint64_t **pShape,
|
||||
int8_t **pValues, uint64_t **pIndices) {
|
||||
fromMLIRSparseTensor<int8_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
|
||||
}
|
||||
void convertFromMLIRSparseTensorC64(void *tensor, uint64_t *pRank,
|
||||
uint64_t *pNse, uint64_t **pShape,
|
||||
complex64 **pValues, uint64_t **pIndices) {
|
||||
fromMLIRSparseTensor<complex64>(tensor, pRank, pNse, pShape, pValues,
|
||||
pIndices);
|
||||
}
|
||||
void convertFromMLIRSparseTensorC32(void *tensor, uint64_t *pRank,
|
||||
uint64_t *pNse, uint64_t **pShape,
|
||||
complex32 **pValues, uint64_t **pIndices) {
|
||||
fromMLIRSparseTensor<complex32>(tensor, pRank, pNse, pShape, pValues,
|
||||
pIndices);
|
||||
}
|
||||
#define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V) \
|
||||
void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank, \
|
||||
uint64_t *pNse, uint64_t **pShape, \
|
||||
V **pValues, uint64_t **pIndices) { \
|
||||
fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, pIndices); \
|
||||
}
|
||||
FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
|
||||
#undef IMPL_CONVERTFROMMLIRSPARSETENSOR
|
||||
|
||||
} // extern "C"
|
||||
|
||||
|
|
Loading…
Reference in New Issue