[PDLL] Add support for single line lambda-like patterns

This allows for defining simple patterns in a single line. The lambda
body of a Pattern expects a single operation rewrite statement:

```
Pattern => replace op<my_dialect.foo>(operands: ValueRange) with operands;
```

Differential Revision: https://reviews.llvm.org/D115835
This commit is contained in:
River Riddle 2022-01-02 03:40:45 +00:00
parent 8cffea061a
commit 3d8b906012
3 changed files with 78 additions and 23 deletions

View File

@ -106,6 +106,10 @@ private:
FailureOr<ast::Decl *> parseTopLevelDecl();
FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
FailureOr<ast::CompoundStmt *>
parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
bool expectTerminalSemicolon = true);
FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
FailureOr<ast::Decl *> parsePatternDecl();
LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
@ -547,6 +551,36 @@ FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
return ast::NamedAttributeDecl::create(ctx, name, attrValue);
}
FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
bool expectTerminalSemicolon) {
consumeToken(Token::equal_arrow);
// Parse the single statement of the lambda body.
SMLoc bodyStartLoc = curToken.getStartLoc();
pushDeclScope();
FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
bool failedToParse =
failed(singleStatement) || failed(processStatementFn(*singleStatement));
popDeclScope();
if (failedToParse)
return failure();
SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
}
FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
if (isa<ast::OpRewriteStmt>(statement))
return success();
return emitError(
statement->getLoc(),
"expected Pattern lambda body to contain a single operation "
"rewrite statement, such as `erase`, `replace`, or `rewrite`");
});
}
FailureOr<ast::Decl *> Parser::parsePatternDecl() {
SMRange loc = curToken.getLoc();
consumeToken(Token::kw_Pattern);
@ -568,29 +602,37 @@ FailureOr<ast::Decl *> Parser::parsePatternDecl() {
// Parse the pattern body.
ast::CompoundStmt *body;
if (curToken.isNot(Token::l_brace))
return emitError("expected `{` to start pattern body");
FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
if (failed(bodyResult))
return failure();
body = *bodyResult;
// Handle a lambda body.
if (curToken.is(Token::equal_arrow)) {
FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
if (failed(bodyResult))
return failure();
body = *bodyResult;
} else {
if (curToken.isNot(Token::l_brace))
return emitError("expected `{` or `=>` to start pattern body");
FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
if (failed(bodyResult))
return failure();
body = *bodyResult;
// Verify the body of the pattern.
auto bodyIt = body->begin(), bodyE = body->end();
for (; bodyIt != bodyE; ++bodyIt) {
// Break when we've found the rewrite statement.
if (isa<ast::OpRewriteStmt>(*bodyIt))
break;
}
if (bodyIt == bodyE) {
return emitError(loc,
"expected Pattern body to terminate with an operation "
"rewrite statement, such as `erase`");
}
if (std::next(bodyIt) != bodyE) {
return emitError((*std::next(bodyIt))->getLoc(),
"Pattern body was terminated by an operation "
"rewrite statement, but found trailing statements");
// Verify the body of the pattern.
auto bodyIt = body->begin(), bodyE = body->end();
for (; bodyIt != bodyE; ++bodyIt) {
// Break when we've found the rewrite statement.
if (isa<ast::OpRewriteStmt>(*bodyIt))
break;
}
if (bodyIt == bodyE) {
return emitError(loc,
"expected Pattern body to terminate with an operation "
"rewrite statement, such as `erase`");
}
if (std::next(bodyIt) != bodyE) {
return emitError((*std::next(bodyIt))->getLoc(),
"Pattern body was terminated by an operation "
"rewrite statement, but found trailing statements");
}
}
return createPatternDecl(loc, name, metadata, body);

View File

@ -1,6 +1,6 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
// CHECK: expected `{` to start pattern body
// CHECK: expected `{` or `=>` to start pattern body
Pattern }
// -----
@ -27,6 +27,11 @@ Pattern {
// -----
// CHECK: expected Pattern lambda body to contain a single operation rewrite statement, such as `erase`, `replace`, or `rewrite`
Pattern => op<>;
// -----
//===----------------------------------------------------------------------===//
// Metadata
//===----------------------------------------------------------------------===//

View File

@ -23,3 +23,11 @@ Pattern NamedPattern {
Pattern NamedPattern with benefit(10), recursion {
erase _: Op;
}
// -----
// CHECK: Module
// CHECK: `-PatternDecl {{.*}} Name<NamedPattern>
// CHECK: `-CompoundStmt
// CHECK: `-EraseStmt
Pattern NamedPattern => erase _: Op;