[mlir][sparse] Only try to compute a better iteraton graph when needed

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134059
This commit is contained in:
Peiming Liu 2022-09-16 18:21:59 +00:00
parent 70f1f302ca
commit b1d1964771
1 changed files with 15 additions and 11 deletions

View File

@ -1832,26 +1832,30 @@ public:
// code generation can proceed. As a last resort, an attempt is made
// to resolve cycles by inserting a conversion.
std::vector<unsigned> topSort;
// Whether the current GenericOp is admissible
// Whether the current GenericOp is admissible.
bool isAdmissible = false;
bool hasCycle = true;
// An const list of all masks that we used for interation graph
// computation. Must be ordered from strict -> loose.
const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
SortMask::kIncludeDense, SortMask::kSparseOnly};
for (auto mask : allMask) {
if (computeIterationGraph(merger, op, topSort, mask) &&
isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
outerParNest)) {
// This is an admissible GenericOp.
isAdmissible = true;
break;
for (auto mask : allMask)
if (computeIterationGraph(merger, op, topSort, mask)) {
hasCycle = false;
if (isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
outerParNest)) {
isAdmissible = true;
break;
}
// else try a set of less strict constraints.
}
// else try a less strict constraints.
}
if (!isAdmissible)
if (hasCycle)
// Give it one last shot to resolve the cycle.
return resolveCycle(merger, rewriter, op);
if (!isAdmissible)
// Inadmissible expression, reject.
return failure();
// Recursively generates code if admissible.
merger.setHasSparseOut(sparseOut != nullptr);