forked from OSchip/llvm-project
[mlir][sparse] Factored out a "FATAL" macro for unrecoverable assertion failure
Depends On D126019 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D126022
This commit is contained in:
parent
88043c1958
commit
774674ce9a
|
@ -84,6 +84,17 @@ static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
|
||||||
return lhs * rhs;
|
return lhs * rhs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This macro helps minimize repetition of this idiom, as well as ensuring
|
||||||
|
// we have some additional output indicating where the error is coming from.
|
||||||
|
// (Since `fprintf` doesn't provide a stacktrace, this helps make it easier
|
||||||
|
// to track down whether an error is coming from our code vs somewhere else
|
||||||
|
// in MLIR.)
|
||||||
|
#define FATAL(...) \
|
||||||
|
{ \
|
||||||
|
fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__); \
|
||||||
|
exit(1); \
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: adjust this so it can be used by `openSparseTensorCOO` too.
|
// TODO: adjust this so it can be used by `openSparseTensorCOO` too.
|
||||||
// That version doesn't have the permutation, and the `dimSizes` are
|
// That version doesn't have the permutation, and the `dimSizes` are
|
||||||
// a pointer/C-array rather than `std::vector`.
|
// a pointer/C-array rather than `std::vector`.
|
||||||
|
@ -262,6 +273,11 @@ private:
|
||||||
template <typename V>
|
template <typename V>
|
||||||
class SparseTensorEnumeratorBase;
|
class SparseTensorEnumeratorBase;
|
||||||
|
|
||||||
|
// Helper macro for generating error messages when some
|
||||||
|
// `SparseTensorStorage<P,I,V>` is cast to `SparseTensorStorageBase`
|
||||||
|
// and then the wrong "partial method specialization" is called.
|
||||||
|
#define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " #NAME);
|
||||||
|
|
||||||
/// Abstract base class for `SparseTensorStorage<P,I,V>`. This class
|
/// Abstract base class for `SparseTensorStorage<P,I,V>`. This class
|
||||||
/// takes responsibility for all the `<P,I,V>`-independent aspects
|
/// takes responsibility for all the `<P,I,V>`-independent aspects
|
||||||
/// of the tensor (e.g., shape, sparsity, permutation). In addition,
|
/// of the tensor (e.g., shape, sparsity, permutation). In addition,
|
||||||
|
@ -325,37 +341,53 @@ public:
|
||||||
#define DECL_NEWENUMERATOR(VNAME, V) \
|
#define DECL_NEWENUMERATOR(VNAME, V) \
|
||||||
virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t, \
|
virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t, \
|
||||||
const uint64_t *) const { \
|
const uint64_t *) const { \
|
||||||
fatal("newEnumerator" #VNAME); \
|
FATAL_PIV("newEnumerator" #VNAME); \
|
||||||
}
|
}
|
||||||
FOREVERY_V(DECL_NEWENUMERATOR)
|
FOREVERY_V(DECL_NEWENUMERATOR)
|
||||||
#undef DECL_NEWENUMERATOR
|
#undef DECL_NEWENUMERATOR
|
||||||
|
|
||||||
/// Overhead storage.
|
/// Overhead storage.
|
||||||
virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
|
virtual void getPointers(std::vector<uint64_t> **, uint64_t) {
|
||||||
virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
|
FATAL_PIV("p64");
|
||||||
virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
|
}
|
||||||
virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
|
virtual void getPointers(std::vector<uint32_t> **, uint64_t) {
|
||||||
virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
|
FATAL_PIV("p32");
|
||||||
virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
|
}
|
||||||
virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
|
virtual void getPointers(std::vector<uint16_t> **, uint64_t) {
|
||||||
virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
|
FATAL_PIV("p16");
|
||||||
|
}
|
||||||
|
virtual void getPointers(std::vector<uint8_t> **, uint64_t) {
|
||||||
|
FATAL_PIV("p8");
|
||||||
|
}
|
||||||
|
virtual void getIndices(std::vector<uint64_t> **, uint64_t) {
|
||||||
|
FATAL_PIV("i64");
|
||||||
|
}
|
||||||
|
virtual void getIndices(std::vector<uint32_t> **, uint64_t) {
|
||||||
|
FATAL_PIV("i32");
|
||||||
|
}
|
||||||
|
virtual void getIndices(std::vector<uint16_t> **, uint64_t) {
|
||||||
|
FATAL_PIV("i16");
|
||||||
|
}
|
||||||
|
virtual void getIndices(std::vector<uint8_t> **, uint64_t) {
|
||||||
|
FATAL_PIV("i8");
|
||||||
|
}
|
||||||
|
|
||||||
/// Primary storage.
|
/// Primary storage.
|
||||||
#define DECL_GETVALUES(VNAME, V) \
|
#define DECL_GETVALUES(VNAME, V) \
|
||||||
virtual void getValues(std::vector<V> **) { fatal("getValues" #VNAME); }
|
virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
|
||||||
FOREVERY_V(DECL_GETVALUES)
|
FOREVERY_V(DECL_GETVALUES)
|
||||||
#undef DECL_GETVALUES
|
#undef DECL_GETVALUES
|
||||||
|
|
||||||
/// Element-wise insertion in lexicographic index order.
|
/// Element-wise insertion in lexicographic index order.
|
||||||
#define DECL_LEXINSERT(VNAME, V) \
|
#define DECL_LEXINSERT(VNAME, V) \
|
||||||
virtual void lexInsert(const uint64_t *, V) { fatal("lexInsert" #VNAME); }
|
virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
|
||||||
FOREVERY_V(DECL_LEXINSERT)
|
FOREVERY_V(DECL_LEXINSERT)
|
||||||
#undef DECL_LEXINSERT
|
#undef DECL_LEXINSERT
|
||||||
|
|
||||||
/// Expanded insertion.
|
/// Expanded insertion.
|
||||||
#define DECL_EXPINSERT(VNAME, V) \
|
#define DECL_EXPINSERT(VNAME, V) \
|
||||||
virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) { \
|
virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) { \
|
||||||
fatal("expInsert" #VNAME); \
|
FATAL_PIV("expInsert" #VNAME); \
|
||||||
}
|
}
|
||||||
FOREVERY_V(DECL_EXPINSERT)
|
FOREVERY_V(DECL_EXPINSERT)
|
||||||
#undef DECL_EXPINSERT
|
#undef DECL_EXPINSERT
|
||||||
|
@ -374,16 +406,13 @@ protected:
|
||||||
SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
|
SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static void fatal(const char *tp) {
|
|
||||||
fprintf(stderr, "unsupported %s\n", tp);
|
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::vector<uint64_t> dimSizes;
|
const std::vector<uint64_t> dimSizes;
|
||||||
std::vector<uint64_t> rev;
|
std::vector<uint64_t> rev;
|
||||||
const std::vector<DimLevelType> dimTypes;
|
const std::vector<DimLevelType> dimTypes;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#undef FATAL_PIV
|
||||||
|
|
||||||
// Forward.
|
// Forward.
|
||||||
template <typename P, typename I, typename V>
|
template <typename P, typename I, typename V>
|
||||||
class SparseTensorEnumerator;
|
class SparseTensorEnumerator;
|
||||||
|
@ -1122,10 +1151,8 @@ static void readMMEHeader(FILE *file, char *filename, char *line,
|
||||||
char symmetry[64];
|
char symmetry[64];
|
||||||
// Read header line.
|
// Read header line.
|
||||||
if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
|
if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
|
||||||
symmetry) != 5) {
|
symmetry) != 5)
|
||||||
fprintf(stderr, "Corrupt header in %s\n", filename);
|
FATAL("Corrupt header in %s\n", filename);
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
// Set properties
|
// Set properties
|
||||||
*isPattern = (strcmp(toLower(field), "pattern") == 0);
|
*isPattern = (strcmp(toLower(field), "pattern") == 0);
|
||||||
*isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
|
*isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
|
||||||
|
@ -1134,26 +1161,20 @@ static void readMMEHeader(FILE *file, char *filename, char *line,
|
||||||
strcmp(toLower(object), "matrix") ||
|
strcmp(toLower(object), "matrix") ||
|
||||||
strcmp(toLower(format), "coordinate") ||
|
strcmp(toLower(format), "coordinate") ||
|
||||||
(strcmp(toLower(field), "real") && !(*isPattern)) ||
|
(strcmp(toLower(field), "real") && !(*isPattern)) ||
|
||||||
(strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
|
(strcmp(toLower(symmetry), "general") && !(*isSymmetric)))
|
||||||
fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
|
FATAL("Cannot find a general sparse matrix in %s\n", filename);
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
// Skip comments.
|
// Skip comments.
|
||||||
while (true) {
|
while (true) {
|
||||||
if (!fgets(line, kColWidth, file)) {
|
if (!fgets(line, kColWidth, file))
|
||||||
fprintf(stderr, "Cannot find data in %s\n", filename);
|
FATAL("Cannot find data in %s\n", filename);
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
if (line[0] != '%')
|
if (line[0] != '%')
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
// Next line contains M N NNZ.
|
// Next line contains M N NNZ.
|
||||||
idata[0] = 2; // rank
|
idata[0] = 2; // rank
|
||||||
if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
|
if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
|
||||||
idata + 1) != 3) {
|
idata + 1) != 3)
|
||||||
fprintf(stderr, "Cannot find size in %s\n", filename);
|
FATAL("Cannot find size in %s\n", filename);
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read the "extended" FROSTT header. Although not part of the documented
|
/// Read the "extended" FROSTT header. Although not part of the documented
|
||||||
|
@ -1164,25 +1185,18 @@ static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
|
||||||
uint64_t *idata) {
|
uint64_t *idata) {
|
||||||
// Skip comments.
|
// Skip comments.
|
||||||
while (true) {
|
while (true) {
|
||||||
if (!fgets(line, kColWidth, file)) {
|
if (!fgets(line, kColWidth, file))
|
||||||
fprintf(stderr, "Cannot find data in %s\n", filename);
|
FATAL("Cannot find data in %s\n", filename);
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
if (line[0] != '#')
|
if (line[0] != '#')
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
// Next line contains RANK and NNZ.
|
// Next line contains RANK and NNZ.
|
||||||
if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
|
if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
|
||||||
fprintf(stderr, "Cannot find metadata in %s\n", filename);
|
FATAL("Cannot find metadata in %s\n", filename);
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
// Followed by a line with the dimension sizes (one per rank).
|
// Followed by a line with the dimension sizes (one per rank).
|
||||||
for (uint64_t r = 0; r < idata[0]; r++) {
|
for (uint64_t r = 0; r < idata[0]; r++)
|
||||||
if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
|
if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
|
||||||
fprintf(stderr, "Cannot find dimension size %s\n", filename);
|
FATAL("Cannot find dimension size %s\n", filename);
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fgets(line, kColWidth, file); // end of line
|
fgets(line, kColWidth, file); // end of line
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1193,12 +1207,10 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
|
||||||
const uint64_t *shape,
|
const uint64_t *shape,
|
||||||
const uint64_t *perm) {
|
const uint64_t *perm) {
|
||||||
// Open the file.
|
// Open the file.
|
||||||
FILE *file = fopen(filename, "r");
|
|
||||||
if (!file) {
|
|
||||||
assert(filename && "Received nullptr for filename");
|
assert(filename && "Received nullptr for filename");
|
||||||
fprintf(stderr, "Cannot find file %s\n", filename);
|
FILE *file = fopen(filename, "r");
|
||||||
exit(1);
|
if (!file)
|
||||||
}
|
FATAL("Cannot find file %s\n", filename);
|
||||||
// Perform some file format dependent set up.
|
// Perform some file format dependent set up.
|
||||||
char line[kColWidth];
|
char line[kColWidth];
|
||||||
uint64_t idata[512];
|
uint64_t idata[512];
|
||||||
|
@ -1209,8 +1221,7 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
|
||||||
} else if (strstr(filename, ".tns")) {
|
} else if (strstr(filename, ".tns")) {
|
||||||
readExtFROSTTHeader(file, filename, line, idata);
|
readExtFROSTTHeader(file, filename, line, idata);
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "Unknown format %s\n", filename);
|
FATAL("Unknown format %s\n", filename);
|
||||||
exit(1);
|
|
||||||
}
|
}
|
||||||
// Prepare sparse tensor object with per-dimension sizes
|
// Prepare sparse tensor object with per-dimension sizes
|
||||||
// and the number of nonzeros as initial capacity.
|
// and the number of nonzeros as initial capacity.
|
||||||
|
@ -1224,10 +1235,8 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
|
||||||
// Read all nonzero elements.
|
// Read all nonzero elements.
|
||||||
std::vector<uint64_t> indices(rank);
|
std::vector<uint64_t> indices(rank);
|
||||||
for (uint64_t k = 0; k < nnz; k++) {
|
for (uint64_t k = 0; k < nnz; k++) {
|
||||||
if (!fgets(line, kColWidth, file)) {
|
if (!fgets(line, kColWidth, file))
|
||||||
fprintf(stderr, "Cannot find next line of data in %s\n", filename);
|
FATAL("Cannot find next line of data in %s\n", filename);
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
char *linePtr = line;
|
char *linePtr = line;
|
||||||
for (uint64_t r = 0; r < rank; r++) {
|
for (uint64_t r = 0; r < rank; r++) {
|
||||||
uint64_t idx = strtoul(linePtr, &linePtr, 10);
|
uint64_t idx = strtoul(linePtr, &linePtr, 10);
|
||||||
|
@ -1290,22 +1299,15 @@ toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
|
||||||
// Verify that perm is a permutation of 0..(rank-1).
|
// Verify that perm is a permutation of 0..(rank-1).
|
||||||
std::vector<uint64_t> order(perm, perm + rank);
|
std::vector<uint64_t> order(perm, perm + rank);
|
||||||
std::sort(order.begin(), order.end());
|
std::sort(order.begin(), order.end());
|
||||||
for (uint64_t i = 0; i < rank; ++i) {
|
for (uint64_t i = 0; i < rank; ++i)
|
||||||
if (i != order[i]) {
|
if (i != order[i])
|
||||||
fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
|
FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify that the sparsity values are supported.
|
// Verify that the sparsity values are supported.
|
||||||
for (uint64_t i = 0; i < rank; ++i) {
|
for (uint64_t i = 0; i < rank; ++i)
|
||||||
if (sparsity[i] != DimLevelType::kDense &&
|
if (sparsity[i] != DimLevelType::kDense &&
|
||||||
sparsity[i] != DimLevelType::kCompressed) {
|
sparsity[i] != DimLevelType::kCompressed)
|
||||||
fprintf(stderr, "Unsupported sparsity value %d\n",
|
FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
|
||||||
static_cast<int>(sparsity[i]));
|
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Convert external format to internal COO.
|
// Convert external format to internal COO.
|
||||||
|
@ -1539,8 +1541,10 @@ _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
|
||||||
CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
|
CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
|
||||||
|
|
||||||
// Unsupported case (add above if needed).
|
// Unsupported case (add above if needed).
|
||||||
fputs("unsupported combination of types\n", stderr);
|
// TODO: better pretty-printing of enum values!
|
||||||
exit(1);
|
FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
|
||||||
|
static_cast<int>(ptrTp), static_cast<int>(indTp),
|
||||||
|
static_cast<int>(valTp));
|
||||||
}
|
}
|
||||||
#undef CASE
|
#undef CASE
|
||||||
#undef CASE_SECSAME
|
#undef CASE_SECSAME
|
||||||
|
@ -1704,10 +1708,8 @@ char *getTensorFilename(index_type id) {
|
||||||
char var[80];
|
char var[80];
|
||||||
sprintf(var, "TENSOR%" PRIu64, id);
|
sprintf(var, "TENSOR%" PRIu64, id);
|
||||||
char *env = getenv(var);
|
char *env = getenv(var);
|
||||||
if (!env) {
|
if (!env)
|
||||||
fprintf(stderr, "Environment variable %s is not set\n", var);
|
FATAL("Environment variable %s is not set\n", var);
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
return env;
|
return env;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue