From a76ee58f3cbcec6e31ff0d25e7d9a89b81a2ccc8 Mon Sep 17 00:00:00 2001 From: Stanislav Funiak Date: Fri, 26 Nov 2021 18:08:50 +0530 Subject: [PATCH] Multi-root PDL matching using upward traversals. This is commit 4 of 4 for the multi-root matching in PDL, discussed in https://llvm.discourse.group/t/rfc-multi-root-pdl-patterns-for-kernel-matching/4148 (topic flagged for review). This PR integrates the various components (root ordering algorithm, nondeterministic execution of PDL bytecode) to implement multi-root PDL matching. The main idea is for the pattern to specify mulitple candidate roots. The PDL-to-PDLInterp lowering selects one of these roots and "hangs" the pattern from this root, traversing the edges downwards (from operation to its operands) when possible and upwards (from values to its uses) when needed. The root is selected by invoking the optimal matching multiple times, once for each candidate root, and the connectors are determined form the optimal matching. The costs in the directed graph are equal to the number of upward edges that need to be traversed when connecting the given two candidate roots. It can be shown that, for this choice of the cost function, "hanging" the pattern an inner node is no better than from the optimal root. The following three main additions were implemented as a part of this PR: 1. OperationPos predicate has been extended to allow tracing the operation accepting a value (the opposite of operation defining a value). 2. Predicate checking if two values are not equal - this is useful to ensure that we do not traverse the edge back downwards after we traversed it upwards. 3. Function for for building the cost graph among the candidate roots. 4. Updated buildPredicateList, building the predicates optimal branching has been determined. Testing: unit tests (an integration test to follow once the stack of commits has landed) Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D108550 --- mlir/include/mlir/Dialect/PDL/IR/PDLOps.td | 45 ++- .../PDLToPDLInterp/PDLToPDLInterp.cpp | 251 ++++++++---- .../lib/Conversion/PDLToPDLInterp/Predicate.h | 66 +++- .../PDLToPDLInterp/PredicateTree.cpp | 361 +++++++++++++++++- .../Conversion/PDLToPDLInterp/PredicateTree.h | 8 +- mlir/lib/Dialect/PDL/IR/PDL.cpp | 166 +++++--- .../pdl-to-pdl-interp-matcher.mlir | 167 +++++++- mlir/test/Dialect/PDL/invalid.mlir | 61 ++- mlir/test/Dialect/PDL/ops.mlir | 30 ++ 9 files changed, 948 insertions(+), 207 deletions(-) diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td index 9e4301b531c1..96ea659dd1f4 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -427,19 +427,16 @@ def PDL_PatternOp : PDL_Op<"pattern", [ ``` }]; - let arguments = (ins OptionalAttr:$rootKind, - Confined:$benefit, + let arguments = (ins Confined:$benefit, OptionalAttr:$sym_name); let regions = (region SizedRegion<1>:$body); let assemblyFormat = [{ - ($sym_name^)? `:` `benefit` `(` $benefit `)` - (`,` `root` `(` $rootKind^ `)`)? attr-dict-with-keyword $body + ($sym_name^)? `:` `benefit` `(` $benefit `)` attr-dict-with-keyword $body }]; let builders = [ - OpBuilder<(ins CArg<"Optional", "llvm::None">:$rootKind, - CArg<"Optional", "1">:$benefit, - CArg<"Optional", "llvm::None">:$name)>, + OpBuilder<(ins CArg<"Optional", "1">:$benefit, + CArg<"Optional", "llvm::None">:$name)>, ]; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// @@ -451,10 +448,6 @@ def PDL_PatternOp : PDL_Op<"pattern", [ /// Returns the rewrite operation of this pattern. RewriteOp getRewriter(); - - /// Return the root operation kind that this pattern matches, or None if - /// there isn't a specific root. - Optional getRootKind(); }]; } @@ -579,19 +572,25 @@ def PDL_ResultsOp : PDL_Op<"results"> { def PDL_RewriteOp : PDL_Op<"rewrite", [ Terminator, HasParent<"pdl::PatternOp">, NoTerminator, NoRegionArguments, - SingleBlock + SingleBlock, AttrSizedOperandSegments ]> { let summary = "Specify the rewrite of a matched pattern"; let description = [{ `pdl.rewrite` operations terminate the region of a `pdl.pattern` and specify - the main rewrite of a `pdl.pattern`, on the specified root operation. The + the main rewrite of a `pdl.pattern`, on the optional root operation. The rewrite is specified either via a string name (`name`) to a native rewrite function, or via the region body. The rewrite region, if specified, must contain a single block. If the rewrite is external it functions similarly to `pdl.apply_native_rewrite`, and takes a set of constant parameters and a set of additional positional values defined within the matcher as arguments. If the rewrite is external, the root operation is - passed to the native function as the first argument. + passed to the native function as the leading arguments. The root operation, + if provided, specifies the starting point in the pattern for the subgraph + isomorphism search. Pattern matching will proceed from this node downward + (towards the defining operation) or upward (towards the users) until all + the operations in the pattern have been matched. If the root is omitted, + the pdl_interp lowering will automatically select the best root of the + pdl.rewrite among all the operations in the pattern. Example: @@ -599,23 +598,31 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [ // Specify an external rewrite function: pdl.rewrite %root with "myExternalRewriter"(%value : !pdl.value) - // Specify the rewrite inline using PDL: + // Specify a rewrite inline using PDL with the given root: pdl.rewrite %root { %op = pdl.operation "foo.op"(%arg0, %arg1) pdl.replace %root with %op } + + // Specify a rewrite inline using PDL, automatically selecting root: + pdl.rewrite { + %op1 = pdl.operation "foo.op"(%arg0, %arg1) + %op2 = pdl.operation "bar.op"(%arg0, %arg1) + pdl.replace %root1 with %op1 + pdl.replace %root2 with %op2 + } ``` }]; - let arguments = (ins PDL_Operation:$root, + let arguments = (ins Optional:$root, OptionalAttr:$name, Variadic:$externalArgs, OptionalAttr:$externalConstParams); let regions = (region AnyRegion:$body); let assemblyFormat = [{ - $root (`with` $name^ ($externalConstParams^)? - (`(` $externalArgs^ `:` type($externalArgs) `)`)?)? - ($body^)? + ($root^)? (`with` $name^ ($externalConstParams^)? + (`(` $externalArgs^ `:` type($externalArgs) `)`)?)? + ($body^)? attr-dict-with-keyword }]; } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index d05586ddc40a..f6dfa29a2dd7 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -43,27 +43,27 @@ private: using ValueMapScope = llvm::ScopedHashTableScope; /// Generate interpreter operations for the tree rooted at the given matcher - /// node. - Block *generateMatcher(MatcherNode &node); + /// node, in the specified region. + Block *generateMatcher(MatcherNode &node, Region ®ion); - /// Get or create an access to the provided positional value within the - /// current block. - Value getValueAt(Block *cur, Position *pos); + /// Get or create an access to the provided positional value in the current + /// block. This operation may mutate the provided block pointer if nested + /// regions (i.e., pdl_interp.iterate) are required. + Value getValueAt(Block *¤tBlock, Position *pos); - /// Create an interpreter predicate operation, branching to the provided true - /// and false destinations. - void generatePredicate(Block *currentBlock, Qualifier *question, - Qualifier *answer, Value val, Block *trueDest, - Block *falseDest); + /// Create the interpreter predicate operations. This operation may mutate the + /// provided current block pointer if nested regions (iterates) are required. + void generate(BoolNode *boolNode, Block *¤tBlock, Value val); - /// Create an interpreter switch predicate operation, with a provided default - /// and several case destinations. - void generateSwitch(SwitchNode *switchNode, Block *currentBlock, - Qualifier *question, Value val, Block *defaultDest); + /// Create the interpreter switch / predicate operations, with several case + /// destinations. This operation never mutates the provided current block + /// pointer, because the switch operation does not need Values beyond `val`. + void generate(SwitchNode *switchNode, Block *currentBlock, Value val); - /// Create the interpreter operations to record a successful pattern match. - void generateRecordMatch(Block *currentBlock, Block *nextBlock, - pdl::PatternOp pattern); + /// Create the interpreter operations to record a successful pattern match + /// using the contained root operation. This operation may mutate the current + /// block pointer if nested regions (i.e., pdl_interp.iterate) are required. + void generate(SuccessNode *successNode, Block *¤tBlock); /// Generate a rewriter function for the given pattern operation, and returns /// a reference to that function. @@ -156,7 +156,8 @@ void PatternLowering::lower(ModuleOp module) { // Generate a root matcher node from the provided PDL module. std::unique_ptr root = MatcherNode::generateMatcherTree( module, predicateBuilder, valueToPosition); - Block *firstMatcherBlock = generateMatcher(*root); + Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody()); + assert(failureBlockStack.empty() && "failed to empty the stack"); // After generation, merged the first matched block into the entry. matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), @@ -164,9 +165,9 @@ void PatternLowering::lower(ModuleOp module) { firstMatcherBlock->erase(); } -Block *PatternLowering::generateMatcher(MatcherNode &node) { +Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) { // Push a new scope for the values used by this matcher. - Block *block = matcherFunc.addBlock(); + Block *block = ®ion.emplaceBlock(); ValueMapScope scope(values); // If this is the return node, simply insert the corresponding interpreter @@ -177,66 +178,114 @@ Block *PatternLowering::generateMatcher(MatcherNode &node) { return block; } - // If this node contains a position, get the corresponding value for this - // block. - Position *position = node.getPosition(); - Value val = position ? getValueAt(block, position) : Value(); - // Get the next block in the match sequence. + // This is intentionally executed first, before we get the value for the + // position associated with the node, so that we preserve an "there exist" + // semantics: if getting a value requires an upward traversal (going from a + // value to its consumers), we want to perform the check on all the consumers + // before we pass control to the failure node. std::unique_ptr &failureNode = node.getFailureNode(); - Block *nextBlock; + Block *failureBlock; if (failureNode) { - nextBlock = generateMatcher(*failureNode); - failureBlockStack.push_back(nextBlock); + failureBlock = generateMatcher(*failureNode, region); + failureBlockStack.push_back(failureBlock); } else { assert(!failureBlockStack.empty() && "expected valid failure block"); - nextBlock = failureBlockStack.back(); + failureBlock = failureBlockStack.back(); } + // If this node contains a position, get the corresponding value for this + // block. + Block *currentBlock = block; + Position *position = node.getPosition(); + Value val = position ? getValueAt(currentBlock, position) : Value(); + // If this value corresponds to an operation, record that we are going to use // its location as part of a fused location. bool isOperationValue = val && val.getType().isa(); if (isOperationValue) locOps.insert(val); - // Generate code for a boolean predicate node. - if (auto *boolNode = dyn_cast(&node)) { - auto *child = generateMatcher(*boolNode->getSuccessNode()); - generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val, - child, nextBlock); + // Dispatch to the correct method based on derived node type. + TypeSwitch(&node) + .Case( + [&](auto *derivedNode) { generate(derivedNode, currentBlock, val); }) + .Case([&](SuccessNode *successNode) { + generate(successNode, currentBlock); + }); - // Generate code for a switch node. - } else if (auto *switchNode = dyn_cast(&node)) { - generateSwitch(switchNode, block, node.getQuestion(), val, nextBlock); - - // Generate code for a success node. - } else if (auto *successNode = dyn_cast(&node)) { - generateRecordMatch(block, nextBlock, successNode->getPattern()); + // Pop all the failure blocks that were inserted due to nesting of + // pdl_interp.iterate. + while (failureBlockStack.back() != failureBlock) { + failureBlockStack.pop_back(); + assert(!failureBlockStack.empty() && "unable to locate failure block"); } + // Pop the new failure block. if (failureNode) failureBlockStack.pop_back(); + if (isOperationValue) locOps.remove(val); + return block; } -Value PatternLowering::getValueAt(Block *cur, Position *pos) { +Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { if (Value val = values.lookup(pos)) return val; // Get the value for the parent position. - Value parentVal = getValueAt(cur, pos->getParent()); + Value parentVal = getValueAt(currentBlock, pos->getParent()); // TODO: Use a location from the position. Location loc = parentVal.getLoc(); - builder.setInsertionPointToEnd(cur); + builder.setInsertionPointToEnd(currentBlock); Value value; switch (pos->getKind()) { - case Predicates::OperationPos: - value = builder.create( - loc, builder.getType(), parentVal); + case Predicates::OperationPos: { + auto *operationPos = cast(pos); + if (!operationPos->isUpward()) { + // Standard (downward) traversal which directly follows the defining op. + value = builder.create( + loc, builder.getType(), parentVal); + break; + } + + // The first operation retrieves the representative value of a range. + // This applies only when the parent is a range of values. + if (parentVal.getType().isa()) + value = builder.create(loc, parentVal, 0); + else + value = parentVal; + + // The second operation retrieves the users. + value = builder.create(loc, value); + + // The third operation iterates over them. + assert(!failureBlockStack.empty() && "expected valid failure block"); + auto foreach = builder.create( + loc, value, failureBlockStack.back(), /*initLoop=*/true); + value = foreach.getLoopVariable(); + + // Create the success and continuation blocks. + Block *successBlock = builder.createBlock(&foreach.region()); + Block *continueBlock = builder.createBlock(successBlock); + builder.create(loc); + failureBlockStack.push_back(continueBlock); + + // The fourth operation extracts the operand(s) of the user at the specified + // index (which can be None, indicating all operands). + builder.setInsertionPointToStart(&foreach.region().front()); + Value operands = builder.create( + loc, parentVal.getType(), value, operationPos->getIndex()); + + // The fifth operation compares the operands to the parent value / range. + builder.create(loc, parentVal, operands, + successBlock, continueBlock); + currentBlock = successBlock; break; + } case Predicates::OperandPos: { auto *operandPos = cast(pos); value = builder.create( @@ -285,41 +334,60 @@ Value PatternLowering::getValueAt(Block *cur, Position *pos) { llvm_unreachable("Generating unknown Position getter"); break; } + values.insert(pos, value); return value; } -void PatternLowering::generatePredicate(Block *currentBlock, - Qualifier *question, Qualifier *answer, - Value val, Block *trueDest, - Block *falseDest) { - builder.setInsertionPointToEnd(currentBlock); +void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, + Value val) { Location loc = val.getLoc(); + Qualifier *question = boolNode->getQuestion(); + Qualifier *answer = boolNode->getAnswer(); + Region *region = currentBlock->getParent(); + + // Execute the getValue queries first, so that we create success + // matcher in the correct (possibly nested) region. + SmallVector args; + if (auto *equalToQuestion = dyn_cast(question)) { + args = {getValueAt(currentBlock, equalToQuestion->getValue())}; + } else if (auto *cstQuestion = dyn_cast(question)) { + for (Position *position : std::get<1>(cstQuestion->getValue())) + args.push_back(getValueAt(currentBlock, position)); + } + + // Generate the matcher in the current (potentially nested) region + // and get the failure successor. + Block *success = generateMatcher(*boolNode->getSuccessNode(), *region); + Block *failure = failureBlockStack.back(); + + // Finally, create the predicate. + builder.setInsertionPointToEnd(currentBlock); Predicates::Kind kind = question->getKind(); switch (kind) { case Predicates::IsNotNullQuestion: - builder.create(loc, val, trueDest, falseDest); + builder.create(loc, val, success, failure); break; case Predicates::OperationNameQuestion: { auto *opNameAnswer = cast(answer); builder.create( - loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest); + loc, val, opNameAnswer->getValue().getStringRef(), success, failure); break; } case Predicates::TypeQuestion: { auto *ans = cast(answer); if (val.getType().isa()) builder.create( - loc, val, ans->getValue().cast(), trueDest, falseDest); + loc, val, ans->getValue().cast(), success, failure); else builder.create( - loc, val, ans->getValue().cast(), trueDest, falseDest); + loc, val, ans->getValue().cast(), success, failure); break; } case Predicates::AttributeQuestion: { auto *ans = cast(answer); builder.create(loc, val, ans->getValue(), - trueDest, falseDest); + success, failure); break; } case Predicates::OperandCountAtLeastQuestion: @@ -327,31 +395,27 @@ void PatternLowering::generatePredicate(Block *currentBlock, builder.create( loc, val, cast(answer)->getValue(), /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, - trueDest, falseDest); + success, failure); break; case Predicates::ResultCountAtLeastQuestion: case Predicates::ResultCountQuestion: builder.create( loc, val, cast(answer)->getValue(), /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, - trueDest, falseDest); + success, failure); break; case Predicates::EqualToQuestion: { - auto *equalToQuestion = cast(question); - builder.create( - loc, val, getValueAt(currentBlock, equalToQuestion->getValue()), - trueDest, falseDest); + bool trueAnswer = isa(answer); + builder.create(loc, val, args.front(), + trueAnswer ? success : failure, + trueAnswer ? failure : success); break; } case Predicates::ConstraintQuestion: { - auto *cstQuestion = cast(question); - SmallVector args; - for (Position *position : std::get<1>(cstQuestion->getValue())) - args.push_back(getValueAt(currentBlock, position)); + auto value = cast(question)->getValue(); builder.create( - loc, std::get<0>(cstQuestion->getValue()), args, - std::get<2>(cstQuestion->getValue()).cast(), trueDest, - falseDest); + loc, std::get<0>(value), args, std::get<2>(value).cast(), + success, failure); break; } default: @@ -373,9 +437,12 @@ static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, builder.create(val.getLoc(), val, values, defaultDest, blocks); } -void PatternLowering::generateSwitch(SwitchNode *switchNode, - Block *currentBlock, Qualifier *question, - Value val, Block *defaultDest) { +void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, + Value val) { + Qualifier *question = switchNode->getQuestion(); + Region *region = currentBlock->getParent(); + Block *defaultDest = failureBlockStack.back(); + // If the switch question is not an exact answer, i.e. for the `at_least` // cases, we generate a special block sequence. Predicates::Kind kind = question->getKind(); @@ -407,12 +474,25 @@ void PatternLowering::generateSwitch(SwitchNode *switchNode, // ... // failureBlockStack.push_back(defaultDest); + Location loc = val.getLoc(); for (unsigned idx : sortedChildren) { auto &child = switchNode->getChild(idx); - Block *childBlock = generateMatcher(*child.second); + Block *childBlock = generateMatcher(*child.second, *region); Block *predicateBlock = builder.createBlock(childBlock); - generatePredicate(predicateBlock, question, child.first, val, childBlock, - defaultDest); + builder.setInsertionPointToEnd(predicateBlock); + unsigned ans = cast(child.first)->getValue(); + switch (kind) { + case Predicates::OperandCountAtLeastQuestion: + builder.create( + loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); + break; + case Predicates::ResultCountAtLeastQuestion: + builder.create( + loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); + break; + default: + llvm_unreachable("Generating invalid AtLeast operation"); + } failureBlockStack.back() = predicateBlock; } Block *firstPredicateBlock = failureBlockStack.pop_back_val(); @@ -426,7 +506,7 @@ void PatternLowering::generateSwitch(SwitchNode *switchNode, // switch. llvm::MapVector children; for (auto &it : switchNode->getChildren()) - children.insert({it.first, generateMatcher(*it.second)}); + children.insert({it.first, generateMatcher(*it.second, *region)}); builder.setInsertionPointToEnd(currentBlock); switch (question->getKind()) { @@ -455,8 +535,10 @@ void PatternLowering::generateSwitch(SwitchNode *switchNode, } } -void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock, - pdl::PatternOp pattern) { +void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { + pdl::PatternOp pattern = successNode->getPattern(); + Value root = successNode->getRoot(); + // Generate a rewriter for the pattern this success node represents, and track // any values used from the match region. SmallVector usedMatchValues; @@ -478,14 +560,15 @@ void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock, // Grab the root kind if present. StringAttr rootKindAttr; - if (Optional rootKind = pattern.getRootKind()) - rootKindAttr = builder.getStringAttr(*rootKind); + if (pdl::OperationOp rootOp = root.getDefiningOp()) + if (Optional rootKind = rootOp.name()) + rootKindAttr = builder.getStringAttr(*rootKind); builder.setInsertionPointToEnd(currentBlock); builder.create( pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(), - nextBlock); + failureBlockStack.back()); } SymbolRefAttr PatternLowering::generateRewriter( @@ -535,8 +618,10 @@ SymbolRefAttr PatternLowering::generateRewriter( // method. pdl::RewriteOp rewriter = pattern.getRewriter(); if (StringAttr rewriteName = rewriter.nameAttr()) { + SmallVector args; + if (rewriter.root()) + args.push_back(mapRewriteValue(rewriter.root())); auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); - SmallVector args(1, mapRewriteValue(rewriter.root())); args.append(mappedArgs.begin(), mappedArgs.end()); builder.create( rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h index e943d55639a5..1b7a3bb98bc5 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -65,8 +65,9 @@ enum Kind : unsigned { // Answers. AttributeAnswer, - TrueAnswer, + FalseAnswer, OperationNameAnswer, + TrueAnswer, TypeAnswer, UnsignedAnswer, }; @@ -216,24 +217,45 @@ struct OperandGroupPosition /// An operation position describes an operation node in the IR. Other position /// kinds are formed with respect to an operation position. -struct OperationPosition : public PredicateBase, - Predicates::OperationPos> { +struct OperationPosition + : public PredicateBase, unsigned>, + Predicates::OperationPos> { + static constexpr unsigned kDown = std::numeric_limits::max(); + explicit OperationPosition(const KeyTy &key) : Base(key) { - parent = key.first; + parent = std::get<0>(key); + } + + /// Returns a hash suitable for the given keytype. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); } /// Gets the root position. static OperationPosition *getRoot(StorageUniquer &uniquer) { - return Base::get(uniquer, nullptr, 0); - } - /// Gets an operation position with the given parent. - static OperationPosition *get(StorageUniquer &uniquer, Position *parent) { - return Base::get(uniquer, parent, parent->getOperationDepth() + 1); + return Base::get(uniquer, nullptr, kDown, 0); } + /// Gets an downward operation position with the given parent. + static OperationPosition *get(StorageUniquer &uniquer, Position *parent) { + return Base::get(uniquer, parent, kDown, parent->getOperationDepth() + 1); + } + + /// Gets an upward operation position with the given parent and operand. + static OperationPosition *get(StorageUniquer &uniquer, Position *parent, + Optional operand) { + return Base::get(uniquer, parent, operand, parent->getOperationDepth() + 1); + } + + /// Returns the operand index for an upward operation position. + Optional getIndex() const { return std::get<1>(key); } + + /// Returns if this operation position is upward, accepting an input. + bool isUpward() const { return getIndex().getValueOr(0) != kDown; } + /// Returns the depth of this position. - unsigned getDepth() const { return key.second; } + unsigned getDepth() const { return std::get<2>(key); } /// Returns if this operation position corresponds to the root. bool isRoot() const { return getDepth() == 0; } @@ -346,6 +368,12 @@ struct TrueAnswer using Base::Base; }; +/// An Answer representing a boolean 'false' value. +struct FalseAnswer + : PredicateBase { + using Base::Base; +}; + /// An Answer representing a `Type` value. The value is stored as either a /// TypeAttr, or an ArrayAttr of TypeAttr. struct TypeAnswer : public PredicateBase(); registerParametricStorageType(); registerParametricStorageType(); + registerSingletonStorageType(); registerSingletonStorageType(); // Register the types of Answers with the uniquer. @@ -485,6 +514,14 @@ public: return OperationPosition::get(uniquer, p); } + /// Returns the position of operation using the value at the given index. + OperationPosition *getUsersOp(Position *p, Optional operand) { + assert((isa(p)) && + "expected result position"); + return OperationPosition::get(uniquer, p, operand); + } + /// Returns an attribute position for an attribute of the given operation. Position *getAttribute(OperationPosition *p, StringRef name) { return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name)); @@ -536,11 +573,16 @@ public: AttributeAnswer::get(uniquer, attr)}; } - /// Create a predicate comparing two values. + /// Create a predicate checking if two values are equal. Predicate getEqualTo(Position *pos) { return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)}; } + /// Create a predicate checking if two values are not equal. + Predicate getNotEqualTo(Position *pos) { + return {EqualToQuestion::get(uniquer, pos), FalseAnswer::get(uniquer)}; + } + /// Create a predicate that applies a generic constraint. Predicate getConstraint(StringRef name, ArrayRef pos, Attribute params) { diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 3061d464fe8b..eef8f3afd5bc 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -7,12 +7,19 @@ //===----------------------------------------------------------------------===// #include "PredicateTree.h" +#include "RootOrdering.h" + #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include + +#define DEBUG_TYPE "pdl-predicate-tree" using namespace mlir; using namespace mlir::pdl_to_pdl_interp; @@ -102,7 +109,8 @@ static void getOperandTreePredicates(std::vector &predList, static void getTreePredicates(std::vector &predList, Value val, PredicateBuilder &builder, DenseMap &inputs, - OperationPosition *pos) { + OperationPosition *pos, + Optional ignoreOperand = llvm::None) { assert(val.getType().isa() && "expected operation"); pdl::OperationOp op = cast(val.getDefiningOp()); OperationPosition *opPos = cast(pos); @@ -158,6 +166,11 @@ static void getTreePredicates(std::vector &predList, bool isVariadic = operandIt.value().getType().isa(); foundVariableLength |= isVariadic; + // Ignore the specified operand, usually because this position was + // visited in an upward traversal via an iterative choice. + if (ignoreOperand && *ignoreOperand == operandIt.index()) + continue; + Position *pos = foundVariableLength ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic) @@ -300,15 +313,302 @@ static void getNonTreePredicates(pdl::PatternOp pattern, } } +namespace { + +/// An op accepting a value at an optional index. +struct OpIndex { + Value parent; + Optional index; +}; + +/// The parent and operand index of each operation for each root, stored +/// as a nested map [root][operation]. +using ParentMaps = DenseMap>; + +} // namespace + +/// Given a pattern, determines the set of roots present in this pattern. +/// These are the operations whose results are not consumed by other operations. +static SmallVector detectRoots(pdl::PatternOp pattern) { + // First, collect all the operations that are used as operands + // to other operations. These are not roots by default. + DenseSet used; + for (auto operationOp : pattern.body().getOps()) { + for (Value operand : operationOp.operands()) + TypeSwitch(operand.getDefiningOp()) + .Case( + [&used](auto resultOp) { used.insert(resultOp.parent()); }); + } + + // Remove the specified root from the use set, so that we can + // always select it as a root, even if it is used by other operations. + if (Value root = pattern.getRewriter().root()) + used.erase(root); + + // Finally, collect all the unused operations. + SmallVector roots; + for (Value operationOp : pattern.body().getOps()) + if (!used.contains(operationOp)) + roots.push_back(operationOp); + + return roots; +} + +/// Given a list of candidate roots, builds the cost graph for connecting them. +/// The graph is formed by traversing the DAG of operations starting from each +/// root and marking the depth of each connector value (operand). Then we join +/// the candidate roots based on the common connector values, taking the one +/// with the minimum depth. Along the way, we compute, for each candidate root, +/// a mapping from each operation (in the DAG underneath this root) to its +/// parent operation and the corresponding operand index. +static void buildCostGraph(ArrayRef roots, RootOrderingGraph &graph, + ParentMaps &parentMaps) { + + // The entry of a queue. The entry consists of the following items: + // * the value in the DAG underneath the root; + // * the parent of the value; + // * the operand index of the value in its parent; + // * the depth of the visited value. + struct Entry { + Entry(Value value, Value parent, Optional index, unsigned depth) + : value(value), parent(parent), index(index), depth(depth) {} + + Value value; + Value parent; + Optional index; + unsigned depth; + }; + + // A root of a value and its depth (distance from root to the value). + struct RootDepth { + Value root; + unsigned depth = 0; + }; + + // Map from candidate connector values to their roots and depths. Using a + // small vector with 1 entry because most values belong to a single root. + llvm::MapVector> connectorsRootsDepths; + + // Perform a breadth-first traversal of the op DAG rooted at each root. + for (Value root : roots) { + // The queue of visited values. A value may be present multiple times in + // the queue, for multiple parents. We only accept the first occurrence, + // which is guaranteed to have the lowest depth. + std::queue toVisit; + toVisit.emplace(root, Value(), 0, 0); + + // The map from value to its parent for the current root. + DenseMap &parentMap = parentMaps[root]; + + while (!toVisit.empty()) { + Entry entry = toVisit.front(); + toVisit.pop(); + // Skip if already visited. + if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second) + continue; + + // Mark the root and depth of the value. + connectorsRootsDepths[entry.value].push_back({root, entry.depth}); + + // Traverse the operands of an operation and result ops. + // We intentionally do not traverse attributes and types, because those + // are expensive to join on. + TypeSwitch(entry.value.getDefiningOp()) + .Case([&](auto operationOp) { + OperandRange operands = operationOp.operands(); + // Special case when we pass all the operands in one range. + // For those, the index is empty. + if (operands.size() == 1 && + operands[0].getType().isa()) { + toVisit.emplace(operands[0], entry.value, llvm::None, + entry.depth + 1); + return; + } + + // Default case: visit all the operands. + for (auto p : llvm::enumerate(operationOp.operands())) + toVisit.emplace(p.value(), entry.value, p.index(), + entry.depth + 1); + }) + .Case([&](auto resultOp) { + toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(), + entry.depth); + }); + } + } + + // Now build the cost graph. + // This is simply a minimum over all depths for the target root. + unsigned nextID = 0; + for (const auto &connectorRootsDepths : connectorsRootsDepths) { + Value value = connectorRootsDepths.first; + ArrayRef rootsDepths = connectorRootsDepths.second; + // If there is only one root for this value, this will not trigger + // any edges in the cost graph (a perf optimization). + if (rootsDepths.size() == 1) + continue; + + for (const RootDepth &p : rootsDepths) { + for (const RootDepth &q : rootsDepths) { + if (&p == &q) + continue; + // Insert or retrieve the property of edge from p to q. + RootOrderingCost &cost = graph[q.root][p.root]; + if (!cost.connector /* new edge */ || cost.cost.first > q.depth) { + if (!cost.connector) + cost.cost.second = nextID++; + cost.cost.first = q.depth; + cost.connector = value; + } + } + } + } + + assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) && + "the pattern contains a candidate root disconnected from the others"); +} + +/// Visit a node during upward traversal. +void visitUpward(std::vector &predList, OpIndex opIndex, + PredicateBuilder &builder, + DenseMap &valueToPosition, Position *&pos, + bool &first) { + Value value = opIndex.parent; + TypeSwitch(value.getDefiningOp()) + .Case([&](auto operationOp) { + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + OperationPosition *opPos = builder.getUsersOp(pos, opIndex.index); + + // Guard against traversing back to where we came from. + if (first) { + Position *parent = pos->getParent(); + predList.emplace_back(opPos, builder.getNotEqualTo(parent)); + first = false; + } + + // Guard against duplicate upward visits. These are not possible, + // because if this value was already visited, it would have been + // cheaper to start the traversal at this value rather than at the + // `connector`, violating the optimality of our spanning tree. + bool inserted = valueToPosition.try_emplace(value, opPos).second; + assert(inserted && "duplicate upward visit"); + + // Obtain the tree predicates at the current value. + getTreePredicates(predList, value, builder, valueToPosition, opPos, + opIndex.index); + + // Update the position + pos = opPos; + }) + .Case([&](auto resultOp) { + // Traverse up an individual result. + auto *opPos = dyn_cast(pos); + assert(opPos && "operations and results must be interleaved"); + pos = builder.getResult(opPos, *opIndex.index); + }) + .Case([&](auto resultOp) { + // Traverse up a group of results. + auto *opPos = dyn_cast(pos); + assert(opPos && "operations and results must be interleaved"); + bool isVariadic = value.getType().isa(); + if (opIndex.index) + pos = builder.getResultGroup(opPos, opIndex.index, isVariadic); + else + pos = builder.getAllResults(opPos); + }); +} + /// Given a pattern operation, build the set of matcher predicates necessary to /// match this pattern. -static void buildPredicateList(pdl::PatternOp pattern, - PredicateBuilder &builder, - std::vector &predList, - DenseMap &valueToPosition) { - getTreePredicates(predList, pattern.getRewriter().root(), builder, - valueToPosition, builder.getRoot()); +static Value buildPredicateList(pdl::PatternOp pattern, + PredicateBuilder &builder, + std::vector &predList, + DenseMap &valueToPosition) { + SmallVector roots = detectRoots(pattern); + + // Build the root ordering graph and compute the parent maps. + RootOrderingGraph graph; + ParentMaps parentMaps; + buildCostGraph(roots, graph, parentMaps); + LLVM_DEBUG({ + llvm::dbgs() << "Graph:\n"; + for (auto &target : graph) { + llvm::dbgs() << " * " << target.first << "\n"; + for (auto &source : target.second) { + RootOrderingCost c = source.second; + llvm::dbgs() << " <- " << source.first << ": " << c.cost.first + << ":" << c.cost.second << " via " << c.connector.getLoc() + << "\n"; + } + } + }); + + // Solve the optimal branching problem for each candidate root, or use the + // provided one. + Value bestRoot = pattern.getRewriter().root(); + OptimalBranching::EdgeList bestEdges; + if (!bestRoot) { + unsigned bestCost = 0; + LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n"); + for (Value root : roots) { + OptimalBranching solver(graph, root); + unsigned cost = solver.solve(); + LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n"); + if (!bestRoot || bestCost > cost) { + bestCost = cost; + bestRoot = root; + bestEdges = solver.preOrderTraversal(roots); + } + } + } else { + OptimalBranching solver(graph, bestRoot); + solver.solve(); + bestEdges = solver.preOrderTraversal(roots); + } + + LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n"); + LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n"); + + // The best root is the starting point for the traversal. Get the tree + // predicates for the DAG rooted at bestRoot. + getTreePredicates(predList, bestRoot, builder, valueToPosition, + builder.getRoot()); + + // Traverse the selected optimal branching. For all edges in order, traverse + // up starting from the connector, until the candidate root is reached, and + // call getTreePredicates at every node along the way. + for (const std::pair &edge : bestEdges) { + Value target = edge.first; + Value source = edge.second; + + // Check if we already visited the target root. This happens in two cases: + // 1) the initial root (bestRoot); + // 2) a root that is dominated by (contained in the subtree rooted at) an + // already visited root. + if (valueToPosition.count(target)) + continue; + + // Determine the connector. + Value connector = graph[target][source].connector; + assert(connector && "invalid edge"); + LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n"); + DenseMap parentMap = parentMaps.lookup(target); + Position *pos = valueToPosition.lookup(connector); + assert(pos && "The value has not been traversed yet"); + bool first = true; + + // Traverse from the connector upwards towards the target root. + for (Value value = connector; value != target;) { + OpIndex opIndex = parentMap.lookup(value); + assert(opIndex.parent && "missing parent"); + visitUpward(predList, opIndex, builder, valueToPosition, pos, first); + value = opIndex.parent; + } + } + getNonTreePredicates(pattern, predList, builder, valueToPosition); + + return bestRoot; } //===----------------------------------------------------------------------===// @@ -382,9 +682,11 @@ struct OrderedPredicateDenseInfo { /// This class wraps a set of ordered predicates that are used within a specific /// pattern operation. struct OrderedPredicateList { - OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {} + OrderedPredicateList(pdl::PatternOp pattern, Value root) + : pattern(pattern), root(root) {} pdl::PatternOp pattern; + Value root; DenseSet predicates; }; } // end anonymous namespace @@ -421,7 +723,8 @@ static void propagatePattern(std::unique_ptr &node, std::vector::iterator end) { if (current == end) { // We've hit the end of a pattern, so create a successful result node. - node = std::make_unique(list.pattern, std::move(node)); + node = + std::make_unique(list.pattern, list.root, std::move(node)); // If the pattern doesn't contain this predicate, ignore it. } else if (list.predicates.find(*current) == list.predicates.end()) { @@ -489,22 +792,37 @@ static void insertExitNode(std::unique_ptr *root) { std::unique_ptr MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, DenseMap &valueToPosition) { - // Collect the set of predicates contained within the pattern operations of - // the module. - SmallVector>, 16> - patternsAndPredicates; + // The set of predicates contained within the pattern operations of the + // module. + struct PatternPredicates { + PatternPredicates(pdl::PatternOp pattern, Value root, + std::vector predicates) + : pattern(pattern), root(root), predicates(std::move(predicates)) {} + + /// A pattern. + pdl::PatternOp pattern; + + /// A root of the pattern chosen among the candidate roots in pdl.rewrite. + Value root; + + /// The extracted predicates for this pattern and root. + std::vector predicates; + }; + + SmallVector patternsAndPredicates; for (pdl::PatternOp pattern : module.getOps()) { std::vector predicateList; - buildPredicateList(pattern, builder, predicateList, valueToPosition); - patternsAndPredicates.emplace_back(pattern, std::move(predicateList)); + Value root = + buildPredicateList(pattern, builder, predicateList, valueToPosition); + patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList)); } // Associate a pattern result with each unique predicate. DenseSet uniqued; for (auto &patternAndPredList : patternsAndPredicates) { - for (auto &predicate : patternAndPredList.second) { + for (auto &predicate : patternAndPredList.predicates) { auto it = uniqued.insert(predicate); - it.first->patternToAnswer.try_emplace(patternAndPredList.first, + it.first->patternToAnswer.try_emplace(patternAndPredList.pattern, predicate.answer); } } @@ -513,8 +831,9 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, std::vector lists; lists.reserve(patternsAndPredicates.size()); for (auto &patternAndPredList : patternsAndPredicates) { - OrderedPredicateList list(patternAndPredList.first); - for (auto &predicate : patternAndPredList.second) { + OrderedPredicateList list(patternAndPredList.pattern, + patternAndPredList.root); + for (auto &predicate : patternAndPredList.predicates) { OrderedPredicate *orderedPredicate = &*uniqued.find(predicate); list.predicates.insert(orderedPredicate); @@ -580,11 +899,11 @@ BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer, // SuccessNode //===----------------------------------------------------------------------===// -SuccessNode::SuccessNode(pdl::PatternOp pattern, +SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr failureNode) : MatcherNode(TypeID::get(), /*position=*/nullptr, /*question=*/nullptr, std::move(failureNode)), - pattern(pattern) {} + pattern(pattern), root(root) {} //===----------------------------------------------------------------------===// // SwitchNode diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h index ac2fa98d7c7b..796eb762c97d 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h @@ -152,7 +152,7 @@ struct ExitNode : public MatcherNode { /// matched. This does not terminate the matcher, as there may be multiple /// successful matches. struct SuccessNode : public MatcherNode { - explicit SuccessNode(pdl::PatternOp pattern, + explicit SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr failureNode); /// Returns if the given matcher node is an instance of this class, used to @@ -164,10 +164,16 @@ struct SuccessNode : public MatcherNode { /// Return the high level pattern operation that is matched with this node. pdl::PatternOp getPattern() const { return pattern; } + /// Return the chosen root of the pattern. + Value getRoot() const { return root; } + private: /// The high level pattern operation that was successfully matched with this /// node. pdl::PatternOp pattern; + + /// The chosen root of the pattern. + Value root; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp index 856297444be1..81a8f60610bc 100644 --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -11,7 +11,8 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/InferTypeOpInterface.h" -#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::pdl; @@ -34,41 +35,55 @@ void PDLDialect::initialize() { // PDL Operations //===----------------------------------------------------------------------===// -/// Returns true if the given operation is used by a "binding" pdl operation -/// within the main matcher body of a `pdl.pattern`. -static bool hasBindingUseInMatcher(Operation *op, Block *matcherBlock) { - for (OpOperand &use : op->getUses()) { - Operation *user = use.getOwner(); - if (user->getBlock() != matcherBlock) - continue; - if (isa(user)) - return true; - // Only the first operand of RewriteOp may be bound to, i.e. the root - // operation of the pattern. - if (isa(user) && use.getOperandNumber() == 0) - return true; +/// Returns true if the given operation is used by a "binding" pdl operation. +static bool hasBindingUse(Operation *op) { + for (Operation *user : op->getUsers()) // A result by itself is not binding, it must also be bound. - if (isa(user) && - hasBindingUseInMatcher(user, matcherBlock)) + if (!isa(user) || hasBindingUse(user)) return true; - } return false; } -/// Returns success if the given operation is used by a "binding" pdl operation -/// within the main matcher body of a `pdl.pattern`. On failure, emits an error -/// with the given context message. -static LogicalResult -verifyHasBindingUseInMatcher(Operation *op, - StringRef bindableContextStr = "`pdl.operation`") { - // If the pattern is not a pattern, there is nothing to do. +/// Returns success if the given operation is not in the main matcher body or +/// is used by a "binding" operation. On failure, emits an error. +static LogicalResult verifyHasBindingUse(Operation *op) { + // If the parent is not a pattern, there is nothing to do. if (!isa(op->getParentOp())) return success(); - if (hasBindingUseInMatcher(op, op->getBlock())) + if (hasBindingUse(op)) return success(); - return op->emitOpError() - << "expected a bindable (i.e. " << bindableContextStr - << ") user when defined in the matcher body of a `pdl.pattern`"; + return op->emitOpError( + "expected a bindable user when defined in the matcher body of a " + "`pdl.pattern`"); +} + +/// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) +/// connected to the given operation. +static void visit(Operation *op, DenseSet &visited) { + // If the parent is not a pattern, there is nothing to do. + if (!isa(op->getParentOp()) || isa(op)) + return; + + // Ignore if already visited. + if (visited.contains(op)) + return; + + // Mark as visited. + visited.insert(op); + + // Traverse the operands / parent. + TypeSwitch(op) + .Case([&visited](auto operation) { + for (Value operand : operation.operands()) + visit(operand.getDefiningOp(), visited); + }) + .Case([&visited](auto result) { + visit(result.parent().getDefiningOp(), visited); + }); + + // Traverse the users. + for (Operation *user : op->getUsers()) + visit(user, visited); } //===----------------------------------------------------------------------===// @@ -104,24 +119,20 @@ static LogicalResult verify(AttributeOp op) { "`pdl.rewrite`"); if (attrValue && attrType) return op.emitOpError("expected only one of [`type`, `value`] to be set"); - return verifyHasBindingUseInMatcher(op); + return verifyHasBindingUse(op); } //===----------------------------------------------------------------------===// // pdl::OperandOp //===----------------------------------------------------------------------===// -static LogicalResult verify(OperandOp op) { - return verifyHasBindingUseInMatcher(op); -} +static LogicalResult verify(OperandOp op) { return verifyHasBindingUse(op); } //===----------------------------------------------------------------------===// // pdl::OperandsOp //===----------------------------------------------------------------------===// -static LogicalResult verify(OperandsOp op) { - return verifyHasBindingUseInMatcher(op); -} +static LogicalResult verify(OperandsOp op) { return verifyHasBindingUse(op); } //===----------------------------------------------------------------------===// // pdl::OperationOp @@ -237,7 +248,7 @@ static LogicalResult verify(OperationOp op) { return failure(); } - return verifyHasBindingUseInMatcher(op, "`pdl.operation` or `pdl.rewrite`"); + return verifyHasBindingUse(op); } bool OperationOp::hasTypeInference() { @@ -256,15 +267,16 @@ bool OperationOp::hasTypeInference() { static LogicalResult verify(PatternOp pattern) { Region &body = pattern.body(); - auto *term = body.front().getTerminator(); - if (!isa(term)) { + Operation *term = body.front().getTerminator(); + auto rewrite_op = dyn_cast(term); + if (!rewrite_op) { return pattern.emitOpError("expected body to terminate with `pdl.rewrite`") .attachNote(term->getLoc()) .append("see terminator defined here"); } - // Check that all values defined in the top-level pattern are referenced at - // least once in the source tree. + // Check that all values defined in the top-level pattern belong to the PDL + // dialect. WalkResult result = body.walk([&](Operation *op) -> WalkResult { if (!isa_and_nonnull(op->getDialect())) { pattern @@ -275,15 +287,61 @@ static LogicalResult verify(PatternOp pattern) { } return WalkResult::advance(); }); - return failure(result.wasInterrupted()); + if (result.wasInterrupted()) + return failure(); + + // Check that there is at least one operation. + if (body.front().getOps().empty()) + return pattern.emitOpError( + "the pattern must contain at least one `pdl.operation`"); + + // Determine if the operations within the pdl.pattern form a connected + // component. This is determined by starting the search from the first + // operand/result/operation and visiting their users / parents / operands. + // We limit our attention to operations that have a user in pdl.rewrite, + // those that do not will be detected via other means (expected bindable + // user). + bool first = true; + DenseSet visited; + for (Operation &op : body.front()) { + // The following are the operations forming the connected component. + if (!isa(op)) + continue; + + // Determine if the operation has a user in `pdl.rewrite`. + bool hasUserInRewrite = false; + for (Operation *user : op.getUsers()) { + Region *region = user->getParentRegion(); + if (isa(user) || + (region && isa(region->getParentOp()))) { + hasUserInRewrite = true; + break; + } + } + + // If the operation does not have a user in `pdl.rewrite`, ignore it. + if (!hasUserInRewrite) + continue; + + if (first) { + // For the first operation, invoke visit. + visit(&op, visited); + first = false; + } else if (!visited.count(&op)) { + // For the subsequent operations, check if already visited. + return pattern + .emitOpError("the operations must form a connected component") + .attachNote(op.getLoc()) + .append("see a disconnected value / operation here"); + } + } + + return success(); } void PatternOp::build(OpBuilder &builder, OperationState &state, - Optional rootKind, Optional benefit, - Optional name) { - build(builder, state, - rootKind ? builder.getStringAttr(*rootKind) : StringAttr(), - builder.getI16IntegerAttr(benefit ? *benefit : 0), + Optional benefit, Optional name) { + build(builder, state, builder.getI16IntegerAttr(benefit ? *benefit : 0), name ? builder.getStringAttr(*name) : StringAttr()); state.regions[0]->emplaceBlock(); } @@ -293,13 +351,6 @@ RewriteOp PatternOp::getRewriter() { return cast(body().front().getTerminator()); } -/// Return the root operation kind that this pattern matches, or None if -/// there isn't a specific root. -Optional PatternOp::getRootKind() { - OperationOp rootOp = cast(getRewriter().root().getDefiningOp()); - return rootOp.name(); -} - //===----------------------------------------------------------------------===// // pdl::ReplaceOp //===----------------------------------------------------------------------===// @@ -380,18 +431,13 @@ static LogicalResult verify(RewriteOp op) { // pdl::TypeOp //===----------------------------------------------------------------------===// -static LogicalResult verify(TypeOp op) { - return verifyHasBindingUseInMatcher( - op, "`pdl.attribute`, `pdl.operand`, or `pdl.operation`"); -} +static LogicalResult verify(TypeOp op) { return verifyHasBindingUse(op); } //===----------------------------------------------------------------------===// // pdl::TypesOp //===----------------------------------------------------------------------===// -static LogicalResult verify(TypesOp op) { - return verifyHasBindingUseInMatcher(op, "`pdl.operands`, or `pdl.operation`"); -} +static LogicalResult verify(TypesOp op) { return verifyHasBindingUse(op); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir index 0af77a24efb4..0efcd60945a7 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -384,7 +384,7 @@ module @switch_result_count_at_least { // ----- // CHECK-LABEL: module @predicate_ordering -module @predicate_ordering { +module @predicate_ordering { // Check that the result is checked for null first, before applying the // constraint. The null check is prevalent in both patterns, so should be // prioritized first. @@ -408,3 +408,168 @@ module @predicate_ordering { pdl.rewrite %apply with "rewriter" } } + + +// ----- + +// CHECK-LABEL: module @multi_root +module @multi_root { + // Check the lowering of a simple two-root pattern. + // This checks that we correctly generate the pdl_interp.choose_op operation + // and tie the break between %root1 and %root2 in favor of %root1. + + // CHECK: func @matcher(%[[ROOT1:.*]]: !pdl.operation) + // CHECK-DAG: %[[VAL1:.*]] = pdl_interp.get_operand 0 of %[[ROOT1]] + // CHECK-DAG: %[[OP1:.*]] = pdl_interp.get_defining_op of %[[VAL1]] + // CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_users of %[[VAL1]] : !pdl.value + // CHECK-DAG: pdl_interp.foreach %[[ROOT2:.*]] : !pdl.operation in %[[OPS]] + // CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands 0 of %[[ROOT2]] + // CHECK-DAG: pdl_interp.are_equal %[[VAL1]], %[[OPERANDS]] : !pdl.value -> ^{{.*}}, ^[[CONTINUE:.*]] + // CHECK-DAG: pdl_interp.continue + // CHECK-DAG: %[[VAL2:.*]] = pdl_interp.get_operand 1 of %[[ROOT2]] + // CHECK-DAG: %[[OP2:.*]] = pdl_interp.get_defining_op of %[[VAL2]] + // CHECK-DAG: pdl_interp.is_not_null %[[OP1]] : !pdl.operation -> ^{{.*}}, ^[[CONTINUE]] + // CHECK-DAG: pdl_interp.is_not_null %[[OP2]] : !pdl.operation + // CHECK-DAG: pdl_interp.is_not_null %[[VAL1]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[VAL2]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[ROOT2]] : !pdl.operation + // CHECK-DAG: pdl_interp.are_equal %[[ROOT2]], %[[ROOT1]] : !pdl.operation -> ^[[CONTINUE]] + + pdl.pattern @rewrite_multi_root : benefit(1) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %root1 = pdl.operation(%val1 : !pdl.value) + %op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value) + pdl.rewrite %root1 with "rewriter"(%root2 : !pdl.operation) + } +} + + +// ----- + +// CHECK-LABEL: module @overlapping_roots +module @overlapping_roots { + // Check the lowering of a degenerate two-root pattern, where one root + // is in the subtree rooted at another. + + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK-DAG: %[[VAL:.*]] = pdl_interp.get_operand 0 of %[[ROOT]] + // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VAL]] + // CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 0 of %[[OP]] + // CHECK-DAG: %[[INPUT2:.*]] = pdl_interp.get_operand 1 of %[[OP]] + // CHECK-DAG: pdl_interp.is_not_null %[[VAL]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[OP]] : !pdl.operation + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT1]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT2]] : !pdl.value + + pdl.pattern @rewrite_overlapping_roots : benefit(1) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op = pdl.operation(%input1, %input2 : !pdl.value, !pdl.value) -> (%type : !pdl.type) + %val = pdl.result 0 of %op + %root = pdl.operation(%val : !pdl.value) + pdl.rewrite with "rewriter"(%root : !pdl.operation) + } +} + +// ----- + +// CHECK-LABEL: module @force_overlapped_root +module @force_overlapped_root { + // Check the lowering of a degenerate two-root pattern, where one root + // is in the subtree rooted at another, and we are forced to use this + // root as the root of the search tree. + + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK-DAG: %[[VAL:.*]] = pdl_interp.get_result 0 of %[[ROOT]] + // CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT]] is 2 + // CHECK-DAG: pdl_interp.check_result_count of %[[ROOT]] is 1 + // CHECK-DAG: %[[INPUT2:.*]] = pdl_interp.get_operand 1 of %[[ROOT]] + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT2]] : !pdl.value + // CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 0 of %[[ROOT]] + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT1]] : !pdl.value + // CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_users of %[[VAL]] : !pdl.value + // CHECK-DAG: pdl_interp.foreach %[[OP:.*]] : !pdl.operation in %[[OPS]] + // CHECK-DAG: pdl_interp.is_not_null %[[OP]] : !pdl.operation + // CHECK-DAG: pdl_interp.check_operand_count of %[[OP]] is 1 + + pdl.pattern @rewrite_forced_overlapped_root : benefit(1) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %root = pdl.operation(%input1, %input2 : !pdl.value, !pdl.value) -> (%type : !pdl.type) + %val = pdl.result 0 of %root + %op = pdl.operation(%val : !pdl.value) + pdl.rewrite %root with "rewriter"(%op : !pdl.operation) + } +} + +// ----- + +// CHECK-LABEL: module @variadic_results_all +module @variadic_results_all { + // Check the correct lowering when using all results of an operation + // and passing it them as operands to another operation. + + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT]] is 0 + // CHECK-DAG: %[[VALS:.*]] = pdl_interp.get_results of %[[ROOT]] : !pdl.range + // CHECK-DAG: %[[VAL0:.*]] = pdl_interp.extract 0 of %[[VALS]] + // CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value + // CHECK-DAG: pdl_interp.foreach %[[OP:.*]] : !pdl.operation in %[[OPS]] + // CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands of %[[OP]] + // CHECK-DAG pdl_interp.are_equal %[[VALS]], %[[OPERANDS]] -> ^{{.*}}, ^[[CONTINUE:.*]] + // CHECK-DAG: pdl_interp.is_not_null %[[OP]] + // CHECK-DAG: pdl_interp.check_result_count of %[[OP]] is 0 + pdl.pattern @variadic_results_all : benefit(1) { + %types = pdl.types + %root = pdl.operation -> (%types : !pdl.range) + %vals = pdl.results of %root + %op = pdl.operation(%vals : !pdl.range) + pdl.rewrite %root with "rewriter"(%op : !pdl.operation) + } +} + +// ----- + +// CHECK-LABEL: module @variadic_results_at +module @variadic_results_at { + // Check the correct lowering when using selected results of an operation + // and passing it them as an operand to another operation. + + // CHECK: func @matcher(%[[ROOT1:.*]]: !pdl.operation) + // CHECK-DAG: %[[VALS:.*]] = pdl_interp.get_operands 0 of %[[ROOT1]] : !pdl.range + // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VALS]] : !pdl.range + // CHECK-DAG: pdl_interp.is_not_null %[[OP]] : !pdl.operation + // CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT1]] is at_least 1 + // CHECK-DAG: pdl_interp.check_result_count of %[[ROOT1]] is 0 + // CHECK-DAG: %[[VAL:.*]] = pdl_interp.get_operands 1 of %[[ROOT1]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[VAL]] + // CHECK-DAG: pdl_interp.is_not_null %[[VALS]] + // CHECK-DAG: %[[VAL0:.*]] = pdl_interp.extract 0 of %[[VALS]] + // CHECK-DAG: %[[ROOTS2:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value + // CHECK-DAG: pdl_interp.foreach %[[ROOT2:.*]] : !pdl.operation in %[[ROOTS2]] { + // CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands 1 of %[[ROOT2]] + // CHECK-DAG: pdl_interp.are_equal %[[VALS]], %[[OPERANDS]] : !pdl.range -> ^{{.*}}, ^[[CONTINUE:.*]] + // CHECK-DAG: pdl_interp.is_not_null %[[ROOT2]] + // CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT2]] is at_least 1 + // CHECK-DAG: pdl_interp.check_result_count of %[[ROOT2]] is 0 + // CHECK-DAG: pdl_interp.check_operand_count of %[[OP]] is 0 + // CHECK-DAG: pdl_interp.check_result_count of %[[OP]] is at_least 1 + pdl.pattern @variadic_results_at : benefit(1) { + %type = pdl.type + %types = pdl.types + %val = pdl.operand + %op = pdl.operation -> (%types, %type : !pdl.range, !pdl.type) + %vals = pdl.results 0 of %op -> !pdl.range + %root1 = pdl.operation(%vals, %val : !pdl.range, !pdl.value) + %root2 = pdl.operation(%val, %vals : !pdl.value, !pdl.range) + pdl.rewrite with "rewriter"(%root1, %root2 : !pdl.operation, !pdl.operation) + } +} diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir index 5c05e1a89d3f..17b7370292b0 100644 --- a/mlir/test/Dialect/PDL/invalid.mlir +++ b/mlir/test/Dialect/PDL/invalid.mlir @@ -67,7 +67,7 @@ pdl.pattern : benefit(1) { // ----- pdl.pattern : benefit(1) { - // expected-error@below {{expected a bindable (i.e. `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + // expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}} %unused = pdl.attribute %op = pdl.operation "foo.op" @@ -81,7 +81,7 @@ pdl.pattern : benefit(1) { //===----------------------------------------------------------------------===// pdl.pattern : benefit(1) { - // expected-error@below {{expected a bindable (i.e. `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + // expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}} %unused = pdl.operand %op = pdl.operation "foo.op" @@ -95,7 +95,7 @@ pdl.pattern : benefit(1) { //===----------------------------------------------------------------------===// pdl.pattern : benefit(1) { - // expected-error@below {{expected a bindable (i.e. `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + // expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}} %unused = pdl.operands %op = pdl.operation "foo.op" @@ -143,7 +143,7 @@ pdl.pattern : benefit(1) { // ----- pdl.pattern : benefit(1) { - // expected-error@below {{expected a bindable (i.e. `pdl.operation` or `pdl.rewrite`) user when defined in the matcher body of a `pdl.pattern`}} + // expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}} %unused = pdl.operation "foo.op" %op = pdl.operation "foo.op" @@ -164,6 +164,12 @@ pdl.pattern : benefit(1) { // ----- +// expected-error@below {{the pattern must contain at least one `pdl.operation`}} +pdl.pattern : benefit(1) { + pdl.rewrite with "foo" +} + +// ----- // expected-error@below {{expected only `pdl` operations within the pattern body}} pdl.pattern : benefit(1) { // expected-note@below {{see non-`pdl` operation defined here}} @@ -173,6 +179,32 @@ pdl.pattern : benefit(1) { pdl.rewrite %root with "foo" } +// ----- +// expected-error@below {{the operations must form a connected component}} +pdl.pattern : benefit(1) { + %op1 = pdl.operation "foo.op" + %op2 = pdl.operation "bar.op" + // expected-note@below {{see a disconnected value / operation here}} + %val = pdl.result 0 of %op2 + pdl.rewrite %op1 with "foo"(%val : !pdl.value) +} + +// ----- +// expected-error@below {{the operations must form a connected component}} +pdl.pattern : benefit(1) { + %type = pdl.type + %op1 = pdl.operation "foo.op" -> (%type : !pdl.type) + %val = pdl.result 0 of %op1 + %op2 = pdl.operation "bar.op"(%val : !pdl.value) + // expected-note@below {{see a disconnected value / operation here}} + %op3 = pdl.operation "baz.op" + pdl.rewrite { + pdl.erase %op1 + pdl.erase %op2 + pdl.erase %op3 + } +} + // ----- pdl.pattern : benefit(1) { @@ -212,7 +244,9 @@ pdl.pattern : benefit(1) { %op = pdl.operation "foo.op" // expected-error@below {{expected rewrite region to be non-empty if external name is not specified}} - "pdl.rewrite"(%op) ({}) : (!pdl.operation) -> () + "pdl.rewrite"(%op) ({}) { + operand_segment_sizes = dense<[1,0]> : vector<2xi32> + } : (!pdl.operation) -> () } // ----- @@ -223,7 +257,9 @@ pdl.pattern : benefit(1) { // expected-error@below {{expected no external arguments when the rewrite is specified inline}} "pdl.rewrite"(%op, %op) ({ ^bb1: - }) : (!pdl.operation, !pdl.operation) -> () + }) { + operand_segment_sizes = dense<1> : vector<2xi32> + }: (!pdl.operation, !pdl.operation) -> () } // ----- @@ -234,7 +270,9 @@ pdl.pattern : benefit(1) { // expected-error@below {{expected no external constant parameters when the rewrite is specified inline}} "pdl.rewrite"(%op) ({ ^bb1: - }) {externalConstParams = []} : (!pdl.operation) -> () + }) { + operand_segment_sizes = dense<[1,0]> : vector<2xi32>, + externalConstParams = []} : (!pdl.operation) -> () } // ----- @@ -245,7 +283,10 @@ pdl.pattern : benefit(1) { // expected-error@below {{expected rewrite region to be empty when rewrite is external}} "pdl.rewrite"(%op) ({ ^bb1: - }) {name = "foo"} : (!pdl.operation) -> () + }) { + name = "foo", + operand_segment_sizes = dense<[1,0]> : vector<2xi32> + } : (!pdl.operation) -> () } // ----- @@ -255,7 +296,7 @@ pdl.pattern : benefit(1) { //===----------------------------------------------------------------------===// pdl.pattern : benefit(1) { - // expected-error@below {{expected a bindable (i.e. `pdl.attribute`, `pdl.operand`, or `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + // expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}} %unused = pdl.type %op = pdl.operation "foo.op" @@ -269,7 +310,7 @@ pdl.pattern : benefit(1) { //===----------------------------------------------------------------------===// pdl.pattern : benefit(1) { - // expected-error@below {{expected a bindable (i.e. `pdl.operands`, or `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + // expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}} %unused = pdl.types %op = pdl.operation "foo.op" diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir index 07e98f9e5868..1993895c0d6d 100644 --- a/mlir/test/Dialect/PDL/ops.mlir +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -42,6 +42,36 @@ pdl.pattern @rewrite_with_args_and_params : benefit(1) { // ----- +pdl.pattern @rewrite_multi_root_optimal : benefit(2) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %root1 = pdl.operation(%val1 : !pdl.value) + %op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value) + pdl.rewrite with "rewriter"["I am param"](%root1, %root2 : !pdl.operation, !pdl.operation) +} + +// ----- + +pdl.pattern @rewrite_multi_root_forced : benefit(2) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %root1 = pdl.operation(%val1 : !pdl.value) + %op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value) + pdl.rewrite %root1 with "rewriter"["I am param"](%root2 : !pdl.operation) +} + +// ----- + // Check that the result type of an operation within a rewrite can be inferred // from a pdl.replace. pdl.pattern @infer_type_from_operation_replace : benefit(1) {