[mlir][PDL] Move the formats for PatternOp and RewriteOp to the declarative form.

This is possible now that the declarative assembly form supports regions.

Differential Revision: https://reviews.llvm.org/D86830
This commit is contained in:
River Riddle 2020-08-31 12:34:04 -07:00
parent eaeadce9bd
commit 2481846a30
2 changed files with 13 additions and 104 deletions

View File

@ -296,7 +296,7 @@ def PDL_PatternOp : PDL_Op<"pattern", [IsolatedFromAbove, Symbol]> {
%resultType = pdl.type %resultType = pdl.type
%inputOperand = pdl.input %inputOperand = pdl.input
%root, %results = pdl.operation "foo.op"(%inputOperand) -> (%resultType) %root, %results = pdl.operation "foo.op"(%inputOperand) -> (%resultType)
pdl.rewrite(%root) { pdl.rewrite %root {
pdl.replace %root with (%inputOperand) pdl.replace %root with (%inputOperand)
} }
} }
@ -305,9 +305,13 @@ def PDL_PatternOp : PDL_Op<"pattern", [IsolatedFromAbove, Symbol]> {
let arguments = (ins OptionalAttr<StrAttr>:$rootKind, let arguments = (ins OptionalAttr<StrAttr>:$rootKind,
Confined<I16Attr, [IntNonNegative]>:$benefit, Confined<I16Attr, [IntNonNegative]>:$benefit,
OptionalAttr<StrAttr>:$sym_name); OptionalAttr<SymbolNameAttr>:$sym_name);
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
let assemblyFormat = [{
($sym_name^)? `:` `benefit` `(` $benefit `)`
(`,` `root` `(` $rootKind^ `)`)? attr-dict-with-keyword $body
}];
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, " OpBuilder<"OpBuilder &builder, OperationState &state, "
"Optional<StringRef> rootKind = llvm::None, " "Optional<StringRef> rootKind = llvm::None, "
@ -405,6 +409,12 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
Variadic<PDL_PositionalValue>:$externalArgs, Variadic<PDL_PositionalValue>:$externalArgs,
OptionalAttr<ArrayAttr>:$externalConstParams); OptionalAttr<ArrayAttr>:$externalConstParams);
let regions = (region AnyRegion:$body); let regions = (region AnyRegion:$body);
let assemblyFormat = [{
$root (`with` $name^ ($externalConstParams^)?
(`(` $externalArgs^ `:` type($externalArgs) `)`)?)?
($body^)?
attr-dict-with-keyword
}];
} }
def PDL_RewriteEndOp : PDL_Op<"rewrite_end", [Terminator, def PDL_RewriteEndOp : PDL_Op<"rewrite_end", [Terminator,

View File

@ -321,41 +321,6 @@ bool OperationOp::hasTypeInference() {
// pdl::PatternOp // pdl::PatternOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static ParseResult parsePatternOp(OpAsmParser &p, OperationState &state) {
StringAttr name;
p.parseOptionalSymbolName(name, SymbolTable::getSymbolAttrName(),
state.attributes);
// Parse the benefit.
IntegerAttr benefitAttr;
if (p.parseColon() || p.parseKeyword("benefit") || p.parseLParen() ||
p.parseAttribute(benefitAttr, p.getBuilder().getIntegerType(16),
"benefit", state.attributes) ||
p.parseRParen())
return failure();
// Parse the pattern body.
if (p.parseOptionalAttrDictWithKeyword(state.attributes) ||
p.parseRegion(*state.addRegion(), None, None))
return failure();
return success();
}
static void print(OpAsmPrinter &p, PatternOp op) {
p << "pdl.pattern";
if (Optional<StringRef> name = op.sym_name()) {
p << ' ';
p.printSymbolName(*name);
}
p << " : benefit(";
p.printAttributeWithoutType(op.benefitAttr());
p << ")";
p.printOptionalAttrDictWithKeyword(
op.getAttrs(), {"benefit", "rootKind", SymbolTable::getSymbolAttrName()});
p.printRegion(op.body());
}
static LogicalResult verify(PatternOp pattern) { static LogicalResult verify(PatternOp pattern) {
Region &body = pattern.body(); Region &body = pattern.body();
auto *term = body.front().getTerminator(); auto *term = body.front().getTerminator();
@ -445,72 +410,6 @@ static LogicalResult verify(ReplaceOp op) {
// pdl::RewriteOp // pdl::RewriteOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static ParseResult parseRewriteOp(OpAsmParser &p, OperationState &state) {
// Parse the root operand.
OpAsmParser::OperandType rootOperand;
if (p.parseOperand(rootOperand) ||
p.resolveOperand(rootOperand, p.getBuilder().getType<OperationType>(),
state.operands))
return failure();
// Parse an external rewrite.
StringAttr nameAttr;
if (succeeded(p.parseOptionalKeyword("with"))) {
if (p.parseAttribute(nameAttr, "name", state.attributes))
return failure();
// Parse the optional set of constant parameters.
ArrayAttr constantParams;
OptionalParseResult constantParamResult = p.parseOptionalAttribute(
constantParams, "externalConstParams", state.attributes);
if (constantParamResult.hasValue() && failed(*constantParamResult))
return failure();
// Parse the optional additional arguments.
if (succeeded(p.parseOptionalLParen())) {
SmallVector<OpAsmParser::OperandType, 4> arguments;
SmallVector<Type, 4> argumentTypes;
llvm::SMLoc argumentLoc = p.getCurrentLocation();
if (p.parseOperandList(arguments) ||
p.parseColonTypeList(argumentTypes) || p.parseRParen() ||
p.resolveOperands(arguments, argumentTypes, argumentLoc,
state.operands))
return failure();
}
}
// If this isn't an external rewrite, parse the region body.
Region &rewriteRegion = *state.addRegion();
if (!nameAttr) {
if (p.parseRegion(rewriteRegion, /*arguments=*/llvm::None,
/*argTypes=*/llvm::None))
return failure();
RewriteOp::ensureTerminator(rewriteRegion, p.getBuilder(), state.location);
}
return p.parseOptionalAttrDictWithKeyword(state.attributes);
}
static void print(OpAsmPrinter &p, RewriteOp op) {
p << "pdl.rewrite " << op.root();
if (Optional<StringRef> name = op.name()) {
p << " with \"" << *name << "\"";
if (ArrayAttr constantParams = op.externalConstParamsAttr())
p << constantParams;
OperandRange externalArgs = op.externalArgs();
if (!externalArgs.empty())
p << "(" << externalArgs << " : " << externalArgs.getTypes() << ")";
} else {
p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
}
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
{"name", "externalConstParams"});
}
static LogicalResult verify(RewriteOp op) { static LogicalResult verify(RewriteOp op) {
Region &rewriteRegion = op.body(); Region &rewriteRegion = op.body();