[mlir][sparse] Factored out a "FATAL" macro for unrecoverable assertion failure

Depends On D126019

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D126022
This commit is contained in:
wren romano 2022-05-19 15:01:23 -07:00
parent 88043c1958
commit 774674ce9a
1 changed files with 79 additions and 77 deletions

View File

@ -84,6 +84,17 @@ static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
return lhs * rhs;
}
// This macro helps minimize repetition of this idiom, as well as ensuring
// we have some additional output indicating where the error is coming from.
// (Since `fprintf` doesn't provide a stacktrace, this helps make it easier
// to track down whether an error is coming from our code vs somewhere else
// in MLIR.)
#define FATAL(...) \
{ \
fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__); \
exit(1); \
}
// TODO: adjust this so it can be used by `openSparseTensorCOO` too.
// That version doesn't have the permutation, and the `dimSizes` are
// a pointer/C-array rather than `std::vector`.
@ -262,6 +273,11 @@ private:
template <typename V>
class SparseTensorEnumeratorBase;
// Helper macro for generating error messages when some
// `SparseTensorStorage<P,I,V>` is cast to `SparseTensorStorageBase`
// and then the wrong "partial method specialization" is called.
#define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " #NAME);
/// Abstract base class for `SparseTensorStorage<P,I,V>`. This class
/// takes responsibility for all the `<P,I,V>`-independent aspects
/// of the tensor (e.g., shape, sparsity, permutation). In addition,
@ -325,37 +341,53 @@ public:
#define DECL_NEWENUMERATOR(VNAME, V) \
virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t, \
const uint64_t *) const { \
fatal("newEnumerator" #VNAME); \
FATAL_PIV("newEnumerator" #VNAME); \
}
FOREVERY_V(DECL_NEWENUMERATOR)
#undef DECL_NEWENUMERATOR
/// Overhead storage.
virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
virtual void getPointers(std::vector<uint64_t> **, uint64_t) {
FATAL_PIV("p64");
}
virtual void getPointers(std::vector<uint32_t> **, uint64_t) {
FATAL_PIV("p32");
}
virtual void getPointers(std::vector<uint16_t> **, uint64_t) {
FATAL_PIV("p16");
}
virtual void getPointers(std::vector<uint8_t> **, uint64_t) {
FATAL_PIV("p8");
}
virtual void getIndices(std::vector<uint64_t> **, uint64_t) {
FATAL_PIV("i64");
}
virtual void getIndices(std::vector<uint32_t> **, uint64_t) {
FATAL_PIV("i32");
}
virtual void getIndices(std::vector<uint16_t> **, uint64_t) {
FATAL_PIV("i16");
}
virtual void getIndices(std::vector<uint8_t> **, uint64_t) {
FATAL_PIV("i8");
}
/// Primary storage.
#define DECL_GETVALUES(VNAME, V) \
virtual void getValues(std::vector<V> **) { fatal("getValues" #VNAME); }
virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
FOREVERY_V(DECL_GETVALUES)
#undef DECL_GETVALUES
/// Element-wise insertion in lexicographic index order.
#define DECL_LEXINSERT(VNAME, V) \
virtual void lexInsert(const uint64_t *, V) { fatal("lexInsert" #VNAME); }
virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
FOREVERY_V(DECL_LEXINSERT)
#undef DECL_LEXINSERT
/// Expanded insertion.
#define DECL_EXPINSERT(VNAME, V) \
virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) { \
fatal("expInsert" #VNAME); \
FATAL_PIV("expInsert" #VNAME); \
}
FOREVERY_V(DECL_EXPINSERT)
#undef DECL_EXPINSERT
@ -374,16 +406,13 @@ protected:
SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
private:
static void fatal(const char *tp) {
fprintf(stderr, "unsupported %s\n", tp);
exit(1);
}
const std::vector<uint64_t> dimSizes;
std::vector<uint64_t> rev;
const std::vector<DimLevelType> dimTypes;
};
#undef FATAL_PIV
// Forward.
template <typename P, typename I, typename V>
class SparseTensorEnumerator;
@ -1122,10 +1151,8 @@ static void readMMEHeader(FILE *file, char *filename, char *line,
char symmetry[64];
// Read header line.
if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
symmetry) != 5) {
fprintf(stderr, "Corrupt header in %s\n", filename);
exit(1);
}
symmetry) != 5)
FATAL("Corrupt header in %s\n", filename);
// Set properties
*isPattern = (strcmp(toLower(field), "pattern") == 0);
*isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
@ -1134,26 +1161,20 @@ static void readMMEHeader(FILE *file, char *filename, char *line,
strcmp(toLower(object), "matrix") ||
strcmp(toLower(format), "coordinate") ||
(strcmp(toLower(field), "real") && !(*isPattern)) ||
(strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
exit(1);
}
(strcmp(toLower(symmetry), "general") && !(*isSymmetric)))
FATAL("Cannot find a general sparse matrix in %s\n", filename);
// Skip comments.
while (true) {
if (!fgets(line, kColWidth, file)) {
fprintf(stderr, "Cannot find data in %s\n", filename);
exit(1);
}
if (!fgets(line, kColWidth, file))
FATAL("Cannot find data in %s\n", filename);
if (line[0] != '%')
break;
}
// Next line contains M N NNZ.
idata[0] = 2; // rank
if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
idata + 1) != 3) {
fprintf(stderr, "Cannot find size in %s\n", filename);
exit(1);
}
idata + 1) != 3)
FATAL("Cannot find size in %s\n", filename);
}
/// Read the "extended" FROSTT header. Although not part of the documented
@ -1164,25 +1185,18 @@ static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
uint64_t *idata) {
// Skip comments.
while (true) {
if (!fgets(line, kColWidth, file)) {
fprintf(stderr, "Cannot find data in %s\n", filename);
exit(1);
}
if (!fgets(line, kColWidth, file))
FATAL("Cannot find data in %s\n", filename);
if (line[0] != '#')
break;
}
// Next line contains RANK and NNZ.
if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
fprintf(stderr, "Cannot find metadata in %s\n", filename);
exit(1);
}
if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
FATAL("Cannot find metadata in %s\n", filename);
// Followed by a line with the dimension sizes (one per rank).
for (uint64_t r = 0; r < idata[0]; r++) {
if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
fprintf(stderr, "Cannot find dimension size %s\n", filename);
exit(1);
}
}
for (uint64_t r = 0; r < idata[0]; r++)
if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
FATAL("Cannot find dimension size %s\n", filename);
fgets(line, kColWidth, file); // end of line
}
@ -1193,12 +1207,10 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
const uint64_t *shape,
const uint64_t *perm) {
// Open the file.
assert(filename && "Received nullptr for filename");
FILE *file = fopen(filename, "r");
if (!file) {
assert(filename && "Received nullptr for filename");
fprintf(stderr, "Cannot find file %s\n", filename);
exit(1);
}
if (!file)
FATAL("Cannot find file %s\n", filename);
// Perform some file format dependent set up.
char line[kColWidth];
uint64_t idata[512];
@ -1209,8 +1221,7 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
} else if (strstr(filename, ".tns")) {
readExtFROSTTHeader(file, filename, line, idata);
} else {
fprintf(stderr, "Unknown format %s\n", filename);
exit(1);
FATAL("Unknown format %s\n", filename);
}
// Prepare sparse tensor object with per-dimension sizes
// and the number of nonzeros as initial capacity.
@ -1224,10 +1235,8 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
// Read all nonzero elements.
std::vector<uint64_t> indices(rank);
for (uint64_t k = 0; k < nnz; k++) {
if (!fgets(line, kColWidth, file)) {
fprintf(stderr, "Cannot find next line of data in %s\n", filename);
exit(1);
}
if (!fgets(line, kColWidth, file))
FATAL("Cannot find next line of data in %s\n", filename);
char *linePtr = line;
for (uint64_t r = 0; r < rank; r++) {
uint64_t idx = strtoul(linePtr, &linePtr, 10);
@ -1290,22 +1299,15 @@ toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
// Verify that perm is a permutation of 0..(rank-1).
std::vector<uint64_t> order(perm, perm + rank);
std::sort(order.begin(), order.end());
for (uint64_t i = 0; i < rank; ++i) {
if (i != order[i]) {
fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
exit(1);
}
}
for (uint64_t i = 0; i < rank; ++i)
if (i != order[i])
FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);
// Verify that the sparsity values are supported.
for (uint64_t i = 0; i < rank; ++i) {
for (uint64_t i = 0; i < rank; ++i)
if (sparsity[i] != DimLevelType::kDense &&
sparsity[i] != DimLevelType::kCompressed) {
fprintf(stderr, "Unsupported sparsity value %d\n",
static_cast<int>(sparsity[i]));
exit(1);
}
}
sparsity[i] != DimLevelType::kCompressed)
FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
#endif
// Convert external format to internal COO.
@ -1539,8 +1541,10 @@ _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
// Unsupported case (add above if needed).
fputs("unsupported combination of types\n", stderr);
exit(1);
// TODO: better pretty-printing of enum values!
FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
static_cast<int>(ptrTp), static_cast<int>(indTp),
static_cast<int>(valTp));
}
#undef CASE
#undef CASE_SECSAME
@ -1704,10 +1708,8 @@ char *getTensorFilename(index_type id) {
char var[80];
sprintf(var, "TENSOR%" PRIu64, id);
char *env = getenv(var);
if (!env) {
fprintf(stderr, "Environment variable %s is not set\n", var);
exit(1);
}
if (!env)
FATAL("Environment variable %s is not set\n", var);
return env;
}