forked from OSchip/llvm-project
[PDLL] Add support for single line lambda-like patterns
This allows for defining simple patterns in a single line. The lambda body of a Pattern expects a single operation rewrite statement: ``` Pattern => replace op<my_dialect.foo>(operands: ValueRange) with operands; ``` Differential Revision: https://reviews.llvm.org/D115835
This commit is contained in:
parent
8cffea061a
commit
3d8b906012
|
@ -106,6 +106,10 @@ private:
|
|||
|
||||
FailureOr<ast::Decl *> parseTopLevelDecl();
|
||||
FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
|
||||
FailureOr<ast::CompoundStmt *>
|
||||
parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
|
||||
bool expectTerminalSemicolon = true);
|
||||
FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
|
||||
FailureOr<ast::Decl *> parsePatternDecl();
|
||||
LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
|
||||
|
||||
|
@ -547,6 +551,36 @@ FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
|
|||
return ast::NamedAttributeDecl::create(ctx, name, attrValue);
|
||||
}
|
||||
|
||||
FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
|
||||
function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
|
||||
bool expectTerminalSemicolon) {
|
||||
consumeToken(Token::equal_arrow);
|
||||
|
||||
// Parse the single statement of the lambda body.
|
||||
SMLoc bodyStartLoc = curToken.getStartLoc();
|
||||
pushDeclScope();
|
||||
FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
|
||||
bool failedToParse =
|
||||
failed(singleStatement) || failed(processStatementFn(*singleStatement));
|
||||
popDeclScope();
|
||||
if (failedToParse)
|
||||
return failure();
|
||||
|
||||
SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
|
||||
return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
|
||||
}
|
||||
|
||||
FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
|
||||
return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
|
||||
if (isa<ast::OpRewriteStmt>(statement))
|
||||
return success();
|
||||
return emitError(
|
||||
statement->getLoc(),
|
||||
"expected Pattern lambda body to contain a single operation "
|
||||
"rewrite statement, such as `erase`, `replace`, or `rewrite`");
|
||||
});
|
||||
}
|
||||
|
||||
FailureOr<ast::Decl *> Parser::parsePatternDecl() {
|
||||
SMRange loc = curToken.getLoc();
|
||||
consumeToken(Token::kw_Pattern);
|
||||
|
@ -568,29 +602,37 @@ FailureOr<ast::Decl *> Parser::parsePatternDecl() {
|
|||
// Parse the pattern body.
|
||||
ast::CompoundStmt *body;
|
||||
|
||||
if (curToken.isNot(Token::l_brace))
|
||||
return emitError("expected `{` to start pattern body");
|
||||
FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
|
||||
if (failed(bodyResult))
|
||||
return failure();
|
||||
body = *bodyResult;
|
||||
// Handle a lambda body.
|
||||
if (curToken.is(Token::equal_arrow)) {
|
||||
FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
|
||||
if (failed(bodyResult))
|
||||
return failure();
|
||||
body = *bodyResult;
|
||||
} else {
|
||||
if (curToken.isNot(Token::l_brace))
|
||||
return emitError("expected `{` or `=>` to start pattern body");
|
||||
FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
|
||||
if (failed(bodyResult))
|
||||
return failure();
|
||||
body = *bodyResult;
|
||||
|
||||
// Verify the body of the pattern.
|
||||
auto bodyIt = body->begin(), bodyE = body->end();
|
||||
for (; bodyIt != bodyE; ++bodyIt) {
|
||||
// Break when we've found the rewrite statement.
|
||||
if (isa<ast::OpRewriteStmt>(*bodyIt))
|
||||
break;
|
||||
}
|
||||
if (bodyIt == bodyE) {
|
||||
return emitError(loc,
|
||||
"expected Pattern body to terminate with an operation "
|
||||
"rewrite statement, such as `erase`");
|
||||
}
|
||||
if (std::next(bodyIt) != bodyE) {
|
||||
return emitError((*std::next(bodyIt))->getLoc(),
|
||||
"Pattern body was terminated by an operation "
|
||||
"rewrite statement, but found trailing statements");
|
||||
// Verify the body of the pattern.
|
||||
auto bodyIt = body->begin(), bodyE = body->end();
|
||||
for (; bodyIt != bodyE; ++bodyIt) {
|
||||
// Break when we've found the rewrite statement.
|
||||
if (isa<ast::OpRewriteStmt>(*bodyIt))
|
||||
break;
|
||||
}
|
||||
if (bodyIt == bodyE) {
|
||||
return emitError(loc,
|
||||
"expected Pattern body to terminate with an operation "
|
||||
"rewrite statement, such as `erase`");
|
||||
}
|
||||
if (std::next(bodyIt) != bodyE) {
|
||||
return emitError((*std::next(bodyIt))->getLoc(),
|
||||
"Pattern body was terminated by an operation "
|
||||
"rewrite statement, but found trailing statements");
|
||||
}
|
||||
}
|
||||
|
||||
return createPatternDecl(loc, name, metadata, body);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: expected `{` to start pattern body
|
||||
// CHECK: expected `{` or `=>` to start pattern body
|
||||
Pattern }
|
||||
|
||||
// -----
|
||||
|
@ -27,6 +27,11 @@ Pattern {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK: expected Pattern lambda body to contain a single operation rewrite statement, such as `erase`, `replace`, or `rewrite`
|
||||
Pattern => op<>;
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Metadata
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -23,3 +23,11 @@ Pattern NamedPattern {
|
|||
Pattern NamedPattern with benefit(10), recursion {
|
||||
erase _: Op;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: `-PatternDecl {{.*}} Name<NamedPattern>
|
||||
// CHECK: `-CompoundStmt
|
||||
// CHECK: `-EraseStmt
|
||||
Pattern NamedPattern => erase _: Op;
|
||||
|
|
Loading…
Reference in New Issue