forked from OSchip/llvm-project
[mlir][sparse] Factored out a "FATAL" macro for unrecoverable assertion failure
Depends On D126019 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D126022
This commit is contained in:
parent
88043c1958
commit
774674ce9a
|
@ -84,6 +84,17 @@ static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
|
|||
return lhs * rhs;
|
||||
}
|
||||
|
||||
// This macro helps minimize repetition of this idiom, as well as ensuring
|
||||
// we have some additional output indicating where the error is coming from.
|
||||
// (Since `fprintf` doesn't provide a stacktrace, this helps make it easier
|
||||
// to track down whether an error is coming from our code vs somewhere else
|
||||
// in MLIR.)
|
||||
#define FATAL(...) \
|
||||
{ \
|
||||
fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__); \
|
||||
exit(1); \
|
||||
}
|
||||
|
||||
// TODO: adjust this so it can be used by `openSparseTensorCOO` too.
|
||||
// That version doesn't have the permutation, and the `dimSizes` are
|
||||
// a pointer/C-array rather than `std::vector`.
|
||||
|
@ -262,6 +273,11 @@ private:
|
|||
template <typename V>
|
||||
class SparseTensorEnumeratorBase;
|
||||
|
||||
// Helper macro for generating error messages when some
|
||||
// `SparseTensorStorage<P,I,V>` is cast to `SparseTensorStorageBase`
|
||||
// and then the wrong "partial method specialization" is called.
|
||||
#define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " #NAME);
|
||||
|
||||
/// Abstract base class for `SparseTensorStorage<P,I,V>`. This class
|
||||
/// takes responsibility for all the `<P,I,V>`-independent aspects
|
||||
/// of the tensor (e.g., shape, sparsity, permutation). In addition,
|
||||
|
@ -325,37 +341,53 @@ public:
|
|||
#define DECL_NEWENUMERATOR(VNAME, V) \
|
||||
virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t, \
|
||||
const uint64_t *) const { \
|
||||
fatal("newEnumerator" #VNAME); \
|
||||
FATAL_PIV("newEnumerator" #VNAME); \
|
||||
}
|
||||
FOREVERY_V(DECL_NEWENUMERATOR)
|
||||
#undef DECL_NEWENUMERATOR
|
||||
|
||||
/// Overhead storage.
|
||||
virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
|
||||
virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
|
||||
virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
|
||||
virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
|
||||
virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
|
||||
virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
|
||||
virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
|
||||
virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
|
||||
virtual void getPointers(std::vector<uint64_t> **, uint64_t) {
|
||||
FATAL_PIV("p64");
|
||||
}
|
||||
virtual void getPointers(std::vector<uint32_t> **, uint64_t) {
|
||||
FATAL_PIV("p32");
|
||||
}
|
||||
virtual void getPointers(std::vector<uint16_t> **, uint64_t) {
|
||||
FATAL_PIV("p16");
|
||||
}
|
||||
virtual void getPointers(std::vector<uint8_t> **, uint64_t) {
|
||||
FATAL_PIV("p8");
|
||||
}
|
||||
virtual void getIndices(std::vector<uint64_t> **, uint64_t) {
|
||||
FATAL_PIV("i64");
|
||||
}
|
||||
virtual void getIndices(std::vector<uint32_t> **, uint64_t) {
|
||||
FATAL_PIV("i32");
|
||||
}
|
||||
virtual void getIndices(std::vector<uint16_t> **, uint64_t) {
|
||||
FATAL_PIV("i16");
|
||||
}
|
||||
virtual void getIndices(std::vector<uint8_t> **, uint64_t) {
|
||||
FATAL_PIV("i8");
|
||||
}
|
||||
|
||||
/// Primary storage.
|
||||
#define DECL_GETVALUES(VNAME, V) \
|
||||
virtual void getValues(std::vector<V> **) { fatal("getValues" #VNAME); }
|
||||
virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
|
||||
FOREVERY_V(DECL_GETVALUES)
|
||||
#undef DECL_GETVALUES
|
||||
|
||||
/// Element-wise insertion in lexicographic index order.
|
||||
#define DECL_LEXINSERT(VNAME, V) \
|
||||
virtual void lexInsert(const uint64_t *, V) { fatal("lexInsert" #VNAME); }
|
||||
virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
|
||||
FOREVERY_V(DECL_LEXINSERT)
|
||||
#undef DECL_LEXINSERT
|
||||
|
||||
/// Expanded insertion.
|
||||
#define DECL_EXPINSERT(VNAME, V) \
|
||||
virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) { \
|
||||
fatal("expInsert" #VNAME); \
|
||||
FATAL_PIV("expInsert" #VNAME); \
|
||||
}
|
||||
FOREVERY_V(DECL_EXPINSERT)
|
||||
#undef DECL_EXPINSERT
|
||||
|
@ -374,16 +406,13 @@ protected:
|
|||
SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
|
||||
|
||||
private:
|
||||
static void fatal(const char *tp) {
|
||||
fprintf(stderr, "unsupported %s\n", tp);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const std::vector<uint64_t> dimSizes;
|
||||
std::vector<uint64_t> rev;
|
||||
const std::vector<DimLevelType> dimTypes;
|
||||
};
|
||||
|
||||
#undef FATAL_PIV
|
||||
|
||||
// Forward.
|
||||
template <typename P, typename I, typename V>
|
||||
class SparseTensorEnumerator;
|
||||
|
@ -1122,10 +1151,8 @@ static void readMMEHeader(FILE *file, char *filename, char *line,
|
|||
char symmetry[64];
|
||||
// Read header line.
|
||||
if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
|
||||
symmetry) != 5) {
|
||||
fprintf(stderr, "Corrupt header in %s\n", filename);
|
||||
exit(1);
|
||||
}
|
||||
symmetry) != 5)
|
||||
FATAL("Corrupt header in %s\n", filename);
|
||||
// Set properties
|
||||
*isPattern = (strcmp(toLower(field), "pattern") == 0);
|
||||
*isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
|
||||
|
@ -1134,26 +1161,20 @@ static void readMMEHeader(FILE *file, char *filename, char *line,
|
|||
strcmp(toLower(object), "matrix") ||
|
||||
strcmp(toLower(format), "coordinate") ||
|
||||
(strcmp(toLower(field), "real") && !(*isPattern)) ||
|
||||
(strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
|
||||
fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
|
||||
exit(1);
|
||||
}
|
||||
(strcmp(toLower(symmetry), "general") && !(*isSymmetric)))
|
||||
FATAL("Cannot find a general sparse matrix in %s\n", filename);
|
||||
// Skip comments.
|
||||
while (true) {
|
||||
if (!fgets(line, kColWidth, file)) {
|
||||
fprintf(stderr, "Cannot find data in %s\n", filename);
|
||||
exit(1);
|
||||
}
|
||||
if (!fgets(line, kColWidth, file))
|
||||
FATAL("Cannot find data in %s\n", filename);
|
||||
if (line[0] != '%')
|
||||
break;
|
||||
}
|
||||
// Next line contains M N NNZ.
|
||||
idata[0] = 2; // rank
|
||||
if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
|
||||
idata + 1) != 3) {
|
||||
fprintf(stderr, "Cannot find size in %s\n", filename);
|
||||
exit(1);
|
||||
}
|
||||
idata + 1) != 3)
|
||||
FATAL("Cannot find size in %s\n", filename);
|
||||
}
|
||||
|
||||
/// Read the "extended" FROSTT header. Although not part of the documented
|
||||
|
@ -1164,25 +1185,18 @@ static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
|
|||
uint64_t *idata) {
|
||||
// Skip comments.
|
||||
while (true) {
|
||||
if (!fgets(line, kColWidth, file)) {
|
||||
fprintf(stderr, "Cannot find data in %s\n", filename);
|
||||
exit(1);
|
||||
}
|
||||
if (!fgets(line, kColWidth, file))
|
||||
FATAL("Cannot find data in %s\n", filename);
|
||||
if (line[0] != '#')
|
||||
break;
|
||||
}
|
||||
// Next line contains RANK and NNZ.
|
||||
if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
|
||||
fprintf(stderr, "Cannot find metadata in %s\n", filename);
|
||||
exit(1);
|
||||
}
|
||||
if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
|
||||
FATAL("Cannot find metadata in %s\n", filename);
|
||||
// Followed by a line with the dimension sizes (one per rank).
|
||||
for (uint64_t r = 0; r < idata[0]; r++) {
|
||||
if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
|
||||
fprintf(stderr, "Cannot find dimension size %s\n", filename);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
for (uint64_t r = 0; r < idata[0]; r++)
|
||||
if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
|
||||
FATAL("Cannot find dimension size %s\n", filename);
|
||||
fgets(line, kColWidth, file); // end of line
|
||||
}
|
||||
|
||||
|
@ -1193,12 +1207,10 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
|
|||
const uint64_t *shape,
|
||||
const uint64_t *perm) {
|
||||
// Open the file.
|
||||
assert(filename && "Received nullptr for filename");
|
||||
FILE *file = fopen(filename, "r");
|
||||
if (!file) {
|
||||
assert(filename && "Received nullptr for filename");
|
||||
fprintf(stderr, "Cannot find file %s\n", filename);
|
||||
exit(1);
|
||||
}
|
||||
if (!file)
|
||||
FATAL("Cannot find file %s\n", filename);
|
||||
// Perform some file format dependent set up.
|
||||
char line[kColWidth];
|
||||
uint64_t idata[512];
|
||||
|
@ -1209,8 +1221,7 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
|
|||
} else if (strstr(filename, ".tns")) {
|
||||
readExtFROSTTHeader(file, filename, line, idata);
|
||||
} else {
|
||||
fprintf(stderr, "Unknown format %s\n", filename);
|
||||
exit(1);
|
||||
FATAL("Unknown format %s\n", filename);
|
||||
}
|
||||
// Prepare sparse tensor object with per-dimension sizes
|
||||
// and the number of nonzeros as initial capacity.
|
||||
|
@ -1224,10 +1235,8 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
|
|||
// Read all nonzero elements.
|
||||
std::vector<uint64_t> indices(rank);
|
||||
for (uint64_t k = 0; k < nnz; k++) {
|
||||
if (!fgets(line, kColWidth, file)) {
|
||||
fprintf(stderr, "Cannot find next line of data in %s\n", filename);
|
||||
exit(1);
|
||||
}
|
||||
if (!fgets(line, kColWidth, file))
|
||||
FATAL("Cannot find next line of data in %s\n", filename);
|
||||
char *linePtr = line;
|
||||
for (uint64_t r = 0; r < rank; r++) {
|
||||
uint64_t idx = strtoul(linePtr, &linePtr, 10);
|
||||
|
@ -1290,22 +1299,15 @@ toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
|
|||
// Verify that perm is a permutation of 0..(rank-1).
|
||||
std::vector<uint64_t> order(perm, perm + rank);
|
||||
std::sort(order.begin(), order.end());
|
||||
for (uint64_t i = 0; i < rank; ++i) {
|
||||
if (i != order[i]) {
|
||||
fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
for (uint64_t i = 0; i < rank; ++i)
|
||||
if (i != order[i])
|
||||
FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);
|
||||
|
||||
// Verify that the sparsity values are supported.
|
||||
for (uint64_t i = 0; i < rank; ++i) {
|
||||
for (uint64_t i = 0; i < rank; ++i)
|
||||
if (sparsity[i] != DimLevelType::kDense &&
|
||||
sparsity[i] != DimLevelType::kCompressed) {
|
||||
fprintf(stderr, "Unsupported sparsity value %d\n",
|
||||
static_cast<int>(sparsity[i]));
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
sparsity[i] != DimLevelType::kCompressed)
|
||||
FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
|
||||
#endif
|
||||
|
||||
// Convert external format to internal COO.
|
||||
|
@ -1539,8 +1541,10 @@ _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
|
|||
CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
|
||||
|
||||
// Unsupported case (add above if needed).
|
||||
fputs("unsupported combination of types\n", stderr);
|
||||
exit(1);
|
||||
// TODO: better pretty-printing of enum values!
|
||||
FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
|
||||
static_cast<int>(ptrTp), static_cast<int>(indTp),
|
||||
static_cast<int>(valTp));
|
||||
}
|
||||
#undef CASE
|
||||
#undef CASE_SECSAME
|
||||
|
@ -1704,10 +1708,8 @@ char *getTensorFilename(index_type id) {
|
|||
char var[80];
|
||||
sprintf(var, "TENSOR%" PRIu64, id);
|
||||
char *env = getenv(var);
|
||||
if (!env) {
|
||||
fprintf(stderr, "Environment variable %s is not set\n", var);
|
||||
exit(1);
|
||||
}
|
||||
if (!env)
|
||||
FATAL("Environment variable %s is not set\n", var);
|
||||
return env;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue