diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td index 73b4f263accd..71dae35b38d7 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -296,7 +296,7 @@ def PDL_PatternOp : PDL_Op<"pattern", [IsolatedFromAbove, Symbol]> { %resultType = pdl.type %inputOperand = pdl.input %root, %results = pdl.operation "foo.op"(%inputOperand) -> (%resultType) - pdl.rewrite(%root) { + pdl.rewrite %root { pdl.replace %root with (%inputOperand) } } @@ -305,9 +305,13 @@ def PDL_PatternOp : PDL_Op<"pattern", [IsolatedFromAbove, Symbol]> { let arguments = (ins OptionalAttr:$rootKind, Confined:$benefit, - OptionalAttr:$sym_name); - + OptionalAttr:$sym_name); let regions = (region SizedRegion<1>:$body); + let assemblyFormat = [{ + ($sym_name^)? `:` `benefit` `(` $benefit `)` + (`,` `root` `(` $rootKind^ `)`)? attr-dict-with-keyword $body + }]; + let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, " "Optional rootKind = llvm::None, " @@ -405,6 +409,12 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [ Variadic:$externalArgs, OptionalAttr:$externalConstParams); let regions = (region AnyRegion:$body); + let assemblyFormat = [{ + $root (`with` $name^ ($externalConstParams^)? + (`(` $externalArgs^ `:` type($externalArgs) `)`)?)? + ($body^)? + attr-dict-with-keyword + }]; } def PDL_RewriteEndOp : PDL_Op<"rewrite_end", [Terminator, diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp index 0146f0d50b88..082229b6b394 100644 --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -321,41 +321,6 @@ bool OperationOp::hasTypeInference() { // pdl::PatternOp //===----------------------------------------------------------------------===// -static ParseResult parsePatternOp(OpAsmParser &p, OperationState &state) { - StringAttr name; - p.parseOptionalSymbolName(name, SymbolTable::getSymbolAttrName(), - state.attributes); - - // Parse the benefit. - IntegerAttr benefitAttr; - if (p.parseColon() || p.parseKeyword("benefit") || p.parseLParen() || - p.parseAttribute(benefitAttr, p.getBuilder().getIntegerType(16), - "benefit", state.attributes) || - p.parseRParen()) - return failure(); - - // Parse the pattern body. - if (p.parseOptionalAttrDictWithKeyword(state.attributes) || - p.parseRegion(*state.addRegion(), None, None)) - return failure(); - return success(); -} - -static void print(OpAsmPrinter &p, PatternOp op) { - p << "pdl.pattern"; - if (Optional name = op.sym_name()) { - p << ' '; - p.printSymbolName(*name); - } - p << " : benefit("; - p.printAttributeWithoutType(op.benefitAttr()); - p << ")"; - - p.printOptionalAttrDictWithKeyword( - op.getAttrs(), {"benefit", "rootKind", SymbolTable::getSymbolAttrName()}); - p.printRegion(op.body()); -} - static LogicalResult verify(PatternOp pattern) { Region &body = pattern.body(); auto *term = body.front().getTerminator(); @@ -445,72 +410,6 @@ static LogicalResult verify(ReplaceOp op) { // pdl::RewriteOp //===----------------------------------------------------------------------===// -static ParseResult parseRewriteOp(OpAsmParser &p, OperationState &state) { - // Parse the root operand. - OpAsmParser::OperandType rootOperand; - if (p.parseOperand(rootOperand) || - p.resolveOperand(rootOperand, p.getBuilder().getType(), - state.operands)) - return failure(); - - // Parse an external rewrite. - StringAttr nameAttr; - if (succeeded(p.parseOptionalKeyword("with"))) { - if (p.parseAttribute(nameAttr, "name", state.attributes)) - return failure(); - - // Parse the optional set of constant parameters. - ArrayAttr constantParams; - OptionalParseResult constantParamResult = p.parseOptionalAttribute( - constantParams, "externalConstParams", state.attributes); - if (constantParamResult.hasValue() && failed(*constantParamResult)) - return failure(); - - // Parse the optional additional arguments. - if (succeeded(p.parseOptionalLParen())) { - SmallVector arguments; - SmallVector argumentTypes; - llvm::SMLoc argumentLoc = p.getCurrentLocation(); - if (p.parseOperandList(arguments) || - p.parseColonTypeList(argumentTypes) || p.parseRParen() || - p.resolveOperands(arguments, argumentTypes, argumentLoc, - state.operands)) - return failure(); - } - } - - // If this isn't an external rewrite, parse the region body. - Region &rewriteRegion = *state.addRegion(); - if (!nameAttr) { - if (p.parseRegion(rewriteRegion, /*arguments=*/llvm::None, - /*argTypes=*/llvm::None)) - return failure(); - RewriteOp::ensureTerminator(rewriteRegion, p.getBuilder(), state.location); - } - - return p.parseOptionalAttrDictWithKeyword(state.attributes); -} - -static void print(OpAsmPrinter &p, RewriteOp op) { - p << "pdl.rewrite " << op.root(); - if (Optional name = op.name()) { - p << " with \"" << *name << "\""; - - if (ArrayAttr constantParams = op.externalConstParamsAttr()) - p << constantParams; - - OperandRange externalArgs = op.externalArgs(); - if (!externalArgs.empty()) - p << "(" << externalArgs << " : " << externalArgs.getTypes() << ")"; - } else { - p.printRegion(op.body(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); - } - - p.printOptionalAttrDictWithKeyword(op.getAttrs(), - {"name", "externalConstParams"}); -} - static LogicalResult verify(RewriteOp op) { Region &rewriteRegion = op.body();