From 138803e017739c81b43b73631c7096bfc4d097d8 Mon Sep 17 00:00:00 2001 From: Stanislav Funiak Date: Tue, 4 Jan 2022 08:03:26 +0530 Subject: [PATCH] [MLIR][PDL] Make predicate order deterministic. The tree merging of pattern predicates places the predicates in an unordered set. When the predicates are sorted, they are taken in the set order, not the insertion order. This results in nondeterministic behavior. One solution to this problem would be to use `SetVector`. However, the value `SetVector` does not provide a `find` function for fast O(1) lookups and stores the predicates twice -- once in the set and once in the vector, which is undesirable, because we store patternToAnswer in each predicate. A simpler solution is to store the tie breaking ID (which follows the insertion order), and use this ID to break any ties when comparing predicates. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D116081 --- .../PDLToPDLInterp/PredicateTree.cpp | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 24b2f19e58c2..9fd5de11a83d 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -721,6 +721,11 @@ struct OrderedPredicate { /// opposed to those shared across patterns. unsigned secondary = 0; + /// The tie breaking ID, used to preserve a deterministic (insertion) order + /// among all the predicates with the same priority, depth, and position / + /// predicate dependency. + unsigned id = 0; + /// A map between a pattern operation and the answer to the predicate question /// within that pattern. DenseMap patternToAnswer; @@ -733,12 +738,13 @@ struct OrderedPredicate { // * lower depth // * lower position dependency // * lower predicate dependency + // * lower tie breaking ID auto *rhsPos = rhs.position; return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(), - rhsPos->getKind(), rhs.question->getKind()) > + rhsPos->getKind(), rhs.question->getKind(), rhs.id) > std::make_tuple(rhs.primary, rhs.secondary, position->getOperationDepth(), position->getKind(), - question->getKind()); + question->getKind(), id); } }; @@ -903,6 +909,9 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, auto it = uniqued.insert(predicate); it.first->patternToAnswer.try_emplace(patternAndPredList.pattern, predicate.answer); + // Mark the insertion order (0-based indexing). + if (it.second) + it.first->id = uniqued.size() - 1; } } @@ -939,9 +948,9 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, ordered.reserve(uniqued.size()); for (auto &ip : uniqued) ordered.push_back(&ip); - std::stable_sort( - ordered.begin(), ordered.end(), - [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; }); + llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) { + return *lhs < *rhs; + }); // Build the matchers for each of the pattern predicate lists. std::unique_ptr root;