[mlir][sparse] add missing types to from/to-MLIR conversion routines

This will enable our usual set of element types in external
environments, such as PyTACO support.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D124875
Aart Bik 2022-05-03 14:01:47 -07:00
parent 3a8266902b
commit 1abcdc677c
1 changed file with 44 additions and 0 deletions

@@ -1339,6 +1339,30 @@ void *convertToMLIRSparseTensorF32(uint64_t rank, uint64_t nse, uint64_t *shape,
  return toMLIRSparseTensor<float>(rank, nse, shape, values, indices, perm,
                                   sparse);
}
void *convertToMLIRSparseTensorI64(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   int64_t *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<int64_t>(rank, nse, shape, values, indices, perm,
                                     sparse);
}
void *convertToMLIRSparseTensorI32(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   int32_t *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<int32_t>(rank, nse, shape, values, indices, perm,
                                     sparse);
}
void *convertToMLIRSparseTensorI16(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   int16_t *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<int16_t>(rank, nse, shape, values, indices, perm,
                                     sparse);
}
void *convertToMLIRSparseTensorI8(uint64_t rank, uint64_t nse, uint64_t *shape,
                                  int8_t *values, uint64_t *indices,
                                  uint64_t *perm, uint8_t *sparse) {
  return toMLIRSparseTensor<int8_t>(rank, nse, shape, values, indices, perm,
                                    sparse);
}
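
For illustration only, here is a minimal sketch (not part of this change) of how one of the new integer entry points could be driven from a C++ host program. The 2x3 COO input, the identity permutation, and the per-dimension sparsity flags are made-up example values, and the exact meaning of the perm and sparse arguments should be taken from the runtime documentation rather than from this sketch; the program would also need to be linked against whichever MLIR runtime support library provides these symbols.

// Hypothetical driver for the new I32 conversion; all input data is made up.
#include <cstdint>
#include <cstdio>

extern "C" void *convertToMLIRSparseTensorI32(uint64_t rank, uint64_t nse,
                                              uint64_t *shape, int32_t *values,
                                              uint64_t *indices, uint64_t *perm,
                                              uint8_t *sparse);

int main() {
  // COO form of a 2x3 matrix with two stored entries: (0,0)=1 and (1,2)=5.
  uint64_t shape[] = {2, 3};
  int32_t values[] = {1, 5};
  uint64_t indices[] = {0, 0,  // coordinates of values[0]
                        1, 2}; // coordinates of values[1]
  uint64_t perm[] = {0, 1};    // identity dimension ordering (assumed)
  uint8_t sparse[] = {0, 1};   // dense outer, sparse inner (assumed encoding)
  void *tensor = convertToMLIRSparseTensorI32(/*rank=*/2, /*nse=*/2, shape,
                                              values, indices, perm, sparse);
  std::printf("opaque sparse tensor handle: %p\n", tensor);
  return 0;
}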
/// Converts a sparse tensor to COO-flavored format expressed using C-style
/// data structures. The expected output parameters are pointers for these
@@ -1370,6 +1394,26 @@ void convertFromMLIRSparseTensorF32(void *tensor, uint64_t *pRank,
                                    float **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<float>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI64(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    int64_t **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<int64_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI32(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    int32_t **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<int32_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI16(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    int16_t **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<int16_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
void convertFromMLIRSparseTensorI8(void *tensor, uint64_t *pRank,
                                   uint64_t *pNse, uint64_t **pShape,
                                   int8_t **pValues, uint64_t **pIndices) {
  fromMLIRSparseTensor<int8_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
}
} // extern "C"
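
Again purely as an illustrative sketch under the same assumptions, the reverse direction through the new I32 read-out entry point: the tensor argument is assumed to be a handle obtained from one of the convertToMLIRSparseTensor* calls above, and the ownership of the returned shape, values, and indices buffers is deliberately left open here and should be checked against the runtime implementation.

// Hypothetical read-out of COO data through the new I32 entry point.
#include <cstdint>
#include <cstdio>

extern "C" void convertFromMLIRSparseTensorI32(void *tensor, uint64_t *pRank,
                                               uint64_t *pNse, uint64_t **pShape,
                                               int32_t **pValues,
                                               uint64_t **pIndices);

void dumpCOO(void *tensor) {
  uint64_t rank = 0, nse = 0;
  uint64_t *shape = nullptr;
  uint64_t *indices = nullptr;
  int32_t *values = nullptr;
  convertFromMLIRSparseTensorI32(tensor, &rank, &nse, &shape, &values, &indices);
  std::printf("rank = %llu, nse = %llu, shape[0] = %llu\n",
              (unsigned long long)rank, (unsigned long long)nse,
              (unsigned long long)shape[0]);
  for (uint64_t k = 0; k < nse; ++k) {
    for (uint64_t d = 0; d < rank; ++d)
      std::printf("%llu ", (unsigned long long)indices[k * rank + d]);
    std::printf("-> %d\n", (int)values[k]);
  }
  // NOTE: freeing shape/values/indices is intentionally omitted; consult the
  // runtime sources for the ownership convention before adding it.
}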