[MLIR][PDL] Make predicate order deterministic.

The tree merging of pattern predicates places the predicates in an unordered set. When the predicates are sorted, they are taken in the set order, not the insertion order. This results in nondeterministic behavior.

One solution to this problem would be to use `SetVector`. However, the value `SetVector` does not provide a `find` function for fast O(1) lookups and stores the predicates twice -- once in the set and once in the vector, which is undesirable, because we store patternToAnswer in each predicate. A simpler solution is to store the tie breaking ID (which follows the insertion order), and use this ID to break any ties when comparing predicates.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D116081
This commit is contained in:
Stanislav Funiak 2022-01-04 08:03:26 +05:30 committed by Uday Bondhugula
parent 2692eae574
commit 138803e017
1 changed files with 14 additions and 5 deletions

View File

@ -721,6 +721,11 @@ struct OrderedPredicate {
/// opposed to those shared across patterns.
unsigned secondary = 0;
/// The tie breaking ID, used to preserve a deterministic (insertion) order
/// among all the predicates with the same priority, depth, and position /
/// predicate dependency.
unsigned id = 0;
/// A map between a pattern operation and the answer to the predicate question
/// within that pattern.
DenseMap<Operation *, Qualifier *> patternToAnswer;
@ -733,12 +738,13 @@ struct OrderedPredicate {
// * lower depth
// * lower position dependency
// * lower predicate dependency
// * lower tie breaking ID
auto *rhsPos = rhs.position;
return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
rhsPos->getKind(), rhs.question->getKind()) >
rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
std::make_tuple(rhs.primary, rhs.secondary,
position->getOperationDepth(), position->getKind(),
question->getKind());
question->getKind(), id);
}
};
@ -903,6 +909,9 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
auto it = uniqued.insert(predicate);
it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
predicate.answer);
// Mark the insertion order (0-based indexing).
if (it.second)
it.first->id = uniqued.size() - 1;
}
}
@ -939,9 +948,9 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
ordered.reserve(uniqued.size());
for (auto &ip : uniqued)
ordered.push_back(&ip);
std::stable_sort(
ordered.begin(), ordered.end(),
[](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; });
llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
return *lhs < *rhs;
});
// Build the matchers for each of the pattern predicate lists.
std::unique_ptr<MatcherNode> root;