[PDLL] Add support for single line lambda-like patterns

This allows for defining simple patterns in a single line. The lambda
body of a Pattern expects a single operation rewrite statement:

```
Pattern => replace op<my_dialect.foo>(operands: ValueRange) with operands;
```

Differential Revision: https://reviews.llvm.org/D115835
This commit is contained in:
River Riddle 2022-01-02 03:40:45 +00:00
parent 8cffea061a
commit 3d8b906012
3 changed files with 78 additions and 23 deletions

View File

@ -106,6 +106,10 @@ private:
FailureOr<ast::Decl *> parseTopLevelDecl(); FailureOr<ast::Decl *> parseTopLevelDecl();
FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl(); FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
FailureOr<ast::CompoundStmt *>
parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
bool expectTerminalSemicolon = true);
FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
FailureOr<ast::Decl *> parsePatternDecl(); FailureOr<ast::Decl *> parsePatternDecl();
LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
@ -547,6 +551,36 @@ FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
return ast::NamedAttributeDecl::create(ctx, name, attrValue); return ast::NamedAttributeDecl::create(ctx, name, attrValue);
} }
FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
bool expectTerminalSemicolon) {
consumeToken(Token::equal_arrow);
// Parse the single statement of the lambda body.
SMLoc bodyStartLoc = curToken.getStartLoc();
pushDeclScope();
FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
bool failedToParse =
failed(singleStatement) || failed(processStatementFn(*singleStatement));
popDeclScope();
if (failedToParse)
return failure();
SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
}
FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
if (isa<ast::OpRewriteStmt>(statement))
return success();
return emitError(
statement->getLoc(),
"expected Pattern lambda body to contain a single operation "
"rewrite statement, such as `erase`, `replace`, or `rewrite`");
});
}
FailureOr<ast::Decl *> Parser::parsePatternDecl() { FailureOr<ast::Decl *> Parser::parsePatternDecl() {
SMRange loc = curToken.getLoc(); SMRange loc = curToken.getLoc();
consumeToken(Token::kw_Pattern); consumeToken(Token::kw_Pattern);
@ -568,29 +602,37 @@ FailureOr<ast::Decl *> Parser::parsePatternDecl() {
// Parse the pattern body. // Parse the pattern body.
ast::CompoundStmt *body; ast::CompoundStmt *body;
if (curToken.isNot(Token::l_brace)) // Handle a lambda body.
return emitError("expected `{` to start pattern body"); if (curToken.is(Token::equal_arrow)) {
FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
if (failed(bodyResult)) if (failed(bodyResult))
return failure(); return failure();
body = *bodyResult; body = *bodyResult;
} else {
if (curToken.isNot(Token::l_brace))
return emitError("expected `{` or `=>` to start pattern body");
FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
if (failed(bodyResult))
return failure();
body = *bodyResult;
// Verify the body of the pattern. // Verify the body of the pattern.
auto bodyIt = body->begin(), bodyE = body->end(); auto bodyIt = body->begin(), bodyE = body->end();
for (; bodyIt != bodyE; ++bodyIt) { for (; bodyIt != bodyE; ++bodyIt) {
// Break when we've found the rewrite statement. // Break when we've found the rewrite statement.
if (isa<ast::OpRewriteStmt>(*bodyIt)) if (isa<ast::OpRewriteStmt>(*bodyIt))
break; break;
} }
if (bodyIt == bodyE) { if (bodyIt == bodyE) {
return emitError(loc, return emitError(loc,
"expected Pattern body to terminate with an operation " "expected Pattern body to terminate with an operation "
"rewrite statement, such as `erase`"); "rewrite statement, such as `erase`");
} }
if (std::next(bodyIt) != bodyE) { if (std::next(bodyIt) != bodyE) {
return emitError((*std::next(bodyIt))->getLoc(), return emitError((*std::next(bodyIt))->getLoc(),
"Pattern body was terminated by an operation " "Pattern body was terminated by an operation "
"rewrite statement, but found trailing statements"); "rewrite statement, but found trailing statements");
}
} }
return createPatternDecl(loc, name, metadata, body); return createPatternDecl(loc, name, metadata, body);

View File

@ -1,6 +1,6 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s // RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
// CHECK: expected `{` to start pattern body // CHECK: expected `{` or `=>` to start pattern body
Pattern } Pattern }
// ----- // -----
@ -27,6 +27,11 @@ Pattern {
// ----- // -----
// CHECK: expected Pattern lambda body to contain a single operation rewrite statement, such as `erase`, `replace`, or `rewrite`
Pattern => op<>;
// -----
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Metadata // Metadata
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -23,3 +23,11 @@ Pattern NamedPattern {
Pattern NamedPattern with benefit(10), recursion { Pattern NamedPattern with benefit(10), recursion {
erase _: Op; erase _: Op;
} }
// -----
// CHECK: Module
// CHECK: `-PatternDecl {{.*}} Name<NamedPattern>
// CHECK: `-CompoundStmt
// CHECK: `-EraseStmt
Pattern NamedPattern => erase _: Op;