[mlir:ODS] Deprecate Op parser/printer fields in favor of a new hasCustomAssemblyFormat field

Currently if an operation wants a C++ implemented parser/printer, it specifies inline
code blocks. This is quite problematic for various reasons, e.g. it requires defining
C++ inside of Tablegen which is discouraged when possible, but mainly because
nearly all usages simply forward to static functions (e.g. `static void parseSomeOp(...)`)
with users devising their own standards for how these are defined.

This commit adds support for a `hasCustomAssemblyFormat` bit field that specifies if
a C++ parser/printer is needed, and when set to 1 declares the parse/print methods for
operations to override. For migration purposes, the existing behavior is untouched. Upstream
usages will be replaced in a followup to keep this patch focused on the new implementation.

Differential Revision: https://reviews.llvm.org/D119054
This commit is contained in:
River Riddle 2022-02-04 20:47:01 -08:00
parent 45084eab5e
commit d7f0083dca
5 changed files with 106 additions and 88 deletions

View File

@ -2442,14 +2442,24 @@ class Op<Dialect dialect, string mnemonic, list<Trait> props = []> {
// provided.
bit skipDefaultBuilders = 0;
// Custom parser.
// Custom parser and printer.
// NOTE: These fields are deprecated in favor of `assemblyFormat` or
// `hasCustomAssemblyFormat`, and are slated for deletion.
code parser = ?;
// Custom printer.
code printer = ?;
// Custom assembly format.
/// This field corresponds to a declarative description of the assembly format
/// for this operation. If populated, the `hasCustomAssemblyFormat` field is
/// ignored.
string assemblyFormat = ?;
/// This field indicates that the operation has a custom assembly format
/// implemented in C++. When set to `1` a `parse` and `print` method are generated
/// on the operation class. The operation should implement these methods to
/// support the custom format of the operation. The methods have the form:
/// * ParseResult parse(OpAsmParser &parser, OperationState &result)
/// * void print(OpAsmPrinter &p)
bit hasCustomAssemblyFormat = 0;
// A bit indicating if the operation has additional invariants that need to
// verified (aside from those verified by other ODS constructs). If set to `1`,

View File

@ -577,8 +577,8 @@ static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//
static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
OperationState &result) {
ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType argInfo;
Type argType = parser.getBuilder().getIndexType();
@ -593,12 +593,12 @@ static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
/*enableNameShadowing=*/true);
}
static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
void IsolatedRegionOp::print(OpAsmPrinter &p) {
p << "test.isolated_region ";
p.printOperand(op.getOperand());
p.shadowRegionArgs(op.getRegion(), op.getOperand());
p.printOperand(getOperand());
p.shadowRegionArgs(getRegion(), getOperand());
p << ' ';
p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}
//===----------------------------------------------------------------------===//
@ -613,16 +613,15 @@ RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
// Test GraphRegionOp
//===----------------------------------------------------------------------===//
static ParseResult parseGraphRegionOp(OpAsmParser &parser,
OperationState &result) {
ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the body region, and reuse the operand info as the argument info.
Region *body = result.addRegion();
return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
}
static void print(OpAsmPrinter &p, GraphRegionOp op) {
void GraphRegionOp::print(OpAsmPrinter &p) {
p << "test.graph_region ";
p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}
RegionKind GraphRegionOp::getRegionKind(unsigned index) {
@ -633,24 +632,23 @@ RegionKind GraphRegionOp::getRegionKind(unsigned index) {
// Test AffineScopeOp
//===----------------------------------------------------------------------===//
static ParseResult parseAffineScopeOp(OpAsmParser &parser,
OperationState &result) {
ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the body region, and reuse the operand info as the argument info.
Region *body = result.addRegion();
return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
}
static void print(OpAsmPrinter &p, AffineScopeOp op) {
void AffineScopeOp::print(OpAsmPrinter &p) {
p << "test.affine_scope ";
p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}
//===----------------------------------------------------------------------===//
// Test parser.
//===----------------------------------------------------------------------===//
static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
OperationState &result) {
ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
OperationState &result) {
if (parser.parseOptionalColon())
return success();
uint64_t numResults;
@ -663,13 +661,13 @@ static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
return success();
}
static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
if (unsigned numResults = op->getNumResults())
void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
if (unsigned numResults = getNumResults())
p << " : " << numResults;
}
static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
OperationState &result) {
ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
OperationState &result) {
StringRef keyword;
if (parser.parseKeyword(&keyword))
return failure();
@ -677,15 +675,13 @@ static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
return success();
}
static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
p << " " << op.getKeyword();
}
void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
//===----------------------------------------------------------------------===//
// Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`.
static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
OperationState &result) {
ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
OperationState &result) {
if (parser.parseKeyword("wraps"))
return failure();
@ -715,9 +711,9 @@ static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
return success();
}
static void print(OpAsmPrinter &p, WrappingRegionOp op) {
void WrappingRegionOp::print(OpAsmPrinter &p) {
p << " wraps ";
p.printGenericOp(&op.getRegion().front().front());
p.printGenericOp(&getRegion().front().front());
}
//===----------------------------------------------------------------------===//
@ -726,8 +722,8 @@ static void print(OpAsmPrinter &p, WrappingRegionOp op) {
// parseCustomOperationName
//===----------------------------------------------------------------------===//
static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser,
OperationState &result) {
ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
OperationState &result) {
SMLoc loc = parser.getCurrentLocation();
Location currLocation = parser.getEncodedSourceLoc(loc);
@ -799,11 +795,11 @@ static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser,
return success();
}
static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) {
void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
p << ' ';
p.printOperands(op.getOperands());
p.printOperands(getOperands());
Operation &innerOp = op.getRegion().front().front();
Operation &innerOp = getRegion().front().front();
// Assuming that region has a single non-terminator inner-op, if the inner-op
// meets some criteria (which in this case is a simple one based on the name
// of inner-op), then we can print the entire region in a succinct way.
@ -813,19 +809,19 @@ static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) {
p << " start special.op end";
} else {
p << " (";
p.printRegion(op.getRegion());
p.printRegion(getRegion());
p << ")";
}
p << " : ";
p.printFunctionalType(op);
p.printFunctionalType(*this);
}
//===----------------------------------------------------------------------===//
// Test PolyForOp - parse list of region arguments.
//===----------------------------------------------------------------------===//
static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
// Parse list of region arguments without a delimiter.
if (parser.parseRegionArgumentList(ivsInfo))
@ -838,6 +834,8 @@ static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
return parser.parseRegion(*body, ivsInfo, argTypes);
}
void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
void PolyForOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
@ -1044,8 +1042,8 @@ void SideEffectOp::getEffects(
//===----------------------------------------------------------------------===//
// This op has fancy handling of its SSA result name.
static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
OperationState &result) {
ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
OperationState &result) {
// Add the result types.
for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
result.addTypes(parser.getBuilder().getIntegerType(32));
@ -1081,19 +1079,19 @@ static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
return success();
}
static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
// Note that we only need to print the "name" attribute if the asmprinter
// result name disagrees with it. This can happen in strange cases, e.g.
// when there are conflicts.
bool namesDisagree = op.getNames().size() != op.getNumResults();
bool namesDisagree = getNames().size() != getNumResults();
SmallString<32> resultNameStr;
for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
resultNameStr.clear();
llvm::raw_svector_ostream tmpStream(resultNameStr);
p.printOperand(op.getResult(i), tmpStream);
p.printOperand(getResult(i), tmpStream);
auto expectedName = op.getNames()[i].dyn_cast<StringAttr>();
auto expectedName = getNames()[i].dyn_cast<StringAttr>();
if (!expectedName ||
tmpStream.str().drop_front() != expectedName.getValue()) {
namesDisagree = true;
@ -1101,9 +1099,9 @@ static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
}
if (namesDisagree)
p.printOptionalAttrDictWithKeyword(op->getAttrs());
p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
else
p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
}
// We set the SSA name in the asm syntax to the contents of the name
@ -1142,27 +1140,26 @@ LogicalResult AttrWithTraitOp::verify() {
// RegionIfOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, RegionIfOp op) {
void RegionIfOp::print(OpAsmPrinter &p) {
p << " ";
p.printOperands(op.getOperands());
p << ": " << op.getOperandTypes();
p.printArrowTypeList(op.getResultTypes());
p.printOperands(getOperands());
p << ": " << getOperandTypes();
p.printArrowTypeList(getResultTypes());
p << " then ";
p.printRegion(op.getThenRegion(),
p.printRegion(getThenRegion(),
/*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);
p << " else ";
p.printRegion(op.getElseRegion(),
p.printRegion(getElseRegion(),
/*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);
p << " join ";
p.printRegion(op.getJoinRegion(),
p.printRegion(getJoinRegion(),
/*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);
}
static ParseResult parseRegionIfOp(OpAsmParser &parser,
OperationState &result) {
ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> operandInfos;
SmallVector<Type, 2> operandTypes;
@ -1241,17 +1238,17 @@ void AnyCondOp::getRegionInvocationBounds(
// SingleNoTerminatorCustomAsmOp
//===----------------------------------------------------------------------===//
static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
OperationState &state) {
ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
OperationState &state) {
Region *body = state.addRegion();
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
return success();
}
static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
printer.printRegion(
op.getRegion(), /*printEntryBlockArgs=*/false,
getRegion(), /*printEntryBlockArgs=*/false,
// This op has a single block without terminators. But explicitly mark
// as not printing block terminators for testing.
/*printBlockTerminators=*/false);

View File

@ -360,8 +360,7 @@ def SingleNoTerminatorOp : TEST_Op<"single_no_terminator_op",
def SingleNoTerminatorCustomAsmOp : TEST_Op<"single_no_terminator_custom_asm_op",
[SingleBlock, NoTerminator]> {
let regions = (region SizedRegion<1>);
let parser = [{ return ::parseSingleNoTerminatorCustomAsmOp(parser, result); }];
let printer = [{ return ::print(*this, p); }];
let hasCustomAssemblyFormat = 1;
}
def VariadicNoTerminatorOp : TEST_Op<"variadic_no_terminator_op",
@ -644,9 +643,7 @@ def StringAttrPrettyNameOp
[DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let arguments = (ins StrArrayAttr:$names);
let results = (outs Variadic<I32>:$r);
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
let hasCustomAssemblyFormat = 1;
}
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
@ -1580,14 +1577,12 @@ def TestSignatureConversionNoConverterOp
def ParseIntegerLiteralOp : TEST_Op<"parse_integer_literal"> {
let results = (outs Variadic<Index>:$results);
let parser = [{ return ::parse$cppClass(parser, result); }];
let printer = [{ return ::print(p, *this); }];
let hasCustomAssemblyFormat = 1;
}
def ParseWrappedKeywordOp : TEST_Op<"parse_wrapped_keyword"> {
let arguments = (ins StrAttr:$keyword);
let parser = [{ return ::parse$cppClass(parser, result); }];
let printer = [{ return ::print(p, *this); }];
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@ -1602,8 +1597,7 @@ def IsolatedRegionOp : TEST_Op<"isolated_region", [IsolatedFromAbove]> {
let arguments = (ins Index);
let regions = (region SizedRegion<1>:$region);
let parser = [{ return ::parse$cppClass(parser, result); }];
let printer = [{ return ::print(p, *this); }];
let hasCustomAssemblyFormat = 1;
}
def SSACFGRegionOp : TEST_Op<"ssacfg_region", [
@ -1626,8 +1620,7 @@ def GraphRegionOp : TEST_Op<"graph_region", [
}];
let regions = (region AnyRegion:$region);
let parser = [{ return ::parse$cppClass(parser, result); }];
let printer = [{ return ::print(p, *this); }];
let hasCustomAssemblyFormat = 1;
}
def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> {
@ -1637,8 +1630,7 @@ def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> {
}];
let regions = (region SizedRegion<1>:$region);
let parser = [{ return ::parse$cppClass(parser, result); }];
let printer = [{ return ::print(p, *this); }];
let hasCustomAssemblyFormat = 1;
}
def WrappingRegionOp : TEST_Op<"wrapping_region",
@ -1651,8 +1643,7 @@ def WrappingRegionOp : TEST_Op<"wrapping_region",
let results = (outs Variadic<AnyType>);
let regions = (region SizedRegion<1>:$region);
let parser = [{ return ::parse$cppClass(parser, result); }];
let printer = [{ return ::print(p, *this); }];
let hasCustomAssemblyFormat = 1;
}
def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
@ -1670,12 +1661,10 @@ def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
let results = (outs AnyType);
let regions = (region SizedRegion<1>:$region);
let parser = [{ return ::parse$cppClass(parser, result); }];
let printer = [{ return ::print(p, *this); }];
let hasCustomAssemblyFormat = 1;
}
def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]>
{
def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]> {
let summary = "polyfor operation";
let description = [{
Test op with multiple region arguments, each argument of index type.
@ -1685,7 +1674,7 @@ def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]>
mlir::OpAsmSetValueNameFn setNameFn);
}];
let regions = (region SizedRegion<1>:$region);
let parser = [{ return ::parse$cppClass(parser, result); }];
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@ -2356,8 +2345,6 @@ def RegionIfOp : TEST_Op<"region_if",
parent op.
}];
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseRegionIfOp(parser, result); }];
let arguments = (ins Variadic<AnyType>);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$thenRegion,
@ -2375,6 +2362,7 @@ def RegionIfOp : TEST_Op<"region_if",
}
::mlir::OperandRange getSuccessorEntryOperands(unsigned index);
}];
let hasCustomAssemblyFormat = 1;
}
def AnyCondOp : TEST_Op<"any_cond",

View File

@ -38,8 +38,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
);
let builders = [OpBuilder<(ins "Value":$val)>,
OpBuilder<(ins CArg<"int", "0">:$integer)>];
let parser = [{ foo }];
let printer = [{ bar }];
let hasCustomAssemblyFormat = 1;
let hasCanonicalizer = 1;
let hasFolder = 1;

View File

@ -2125,13 +2125,29 @@ void OpEmitter::genTypeInterfaceMethods() {
}
void OpEmitter::genParser() {
if (!hasStringAttribute(def, "parser") ||
hasStringAttribute(def, "assemblyFormat"))
if (hasStringAttribute(def, "assemblyFormat"))
return;
bool hasCppFormat = def.getValueAsBit("hasCustomAssemblyFormat");
if (!hasStringAttribute(def, "parser") && !hasCppFormat)
return;
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpAsmParser &", "parser");
paramList.emplace_back("::mlir::OperationState &", "result");
// If this uses the cpp format, only generate a declaration.
if (hasCppFormat) {
auto *method = opClass.declareStaticMethod("::mlir::ParseResult", "parse",
std::move(paramList));
ERROR_IF_PRUNED(method, "parse", op);
return;
}
PrintNote(op.getLoc(),
"`parser` and `printer` fields are deprecated and will be removed, "
"please use the `hasCustomAssemblyFormat` field instead");
auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
std::move(paramList));
ERROR_IF_PRUNED(method, "parse", op);
@ -2146,6 +2162,14 @@ void OpEmitter::genPrinter() {
if (hasStringAttribute(def, "assemblyFormat"))
return;
// If this uses the cpp format, only generate a declaration.
if (def.getValueAsBit("hasCustomAssemblyFormat")) {
auto *method = opClass.declareMethod(
"void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
ERROR_IF_PRUNED(method, "print", op);
return;
}
auto *valueInit = def.getValueInit("printer");
StringInit *stringInit = dyn_cast<StringInit>(valueInit);
if (!stringInit)