Add a `qualified` directive to the Op, Attribute, and Type declarative assembly format

This patch introduces a new directive that allow to parse/print attributes and types fully
qualified.
This is a follow-up to ee0908703d which introduces the eliding of the `!dialect.mnemonic` by default and allows to force to fully qualify each type/attribute
individually.

Differential Revision: https://reviews.llvm.org/D116905
This commit is contained in:
Mehdi Amini 2022-01-11 01:26:44 +00:00
parent 140a6b1e5c
commit 63f0c00d38
12 changed files with 202 additions and 12 deletions

View File

@ -651,6 +651,16 @@ The available directives are as follows:
- `input` must be either an operand or result [variable](#variables), the
`operands` directive, or the `results` directive.
* `qualified` ( type_or_attribute )
- Wraps a `type` directive or an attribute parameter.
- Used to force printing the type or attribute prefixed with its dialect
and mnemonic. For example the `vector.multi_reduction` operation has a
`kind` attribute ; by default the declarative assembly will print:
`vector.multi_reduction <minf>, ...` but using `qualified($kind)` in the
declarative assembly format will print it instead as:
`vector.multi_reduction #vector.kind<minf>, ...`.
#### Literals
A literal is either a keyword or punctuation surrounded by \`\`.

View File

@ -489,9 +489,11 @@ for these parameters are expected to return `FailureOr<$cppStorageType>`.
Attribute and type assembly formats have the following directives:
* `params`: capture all parameters of an attribute or type.
* `struct`: generate a "struct-like" parser and printer for a list of key-value
pairs.
* `params`: capture all parameters of an attribute or type.
* `qualified`: mark a parameter to be printed with its leading dialect and
mnemonic.
* `struct`: generate a "struct-like" parser and printer for a list of
key-value pairs.
#### `params` Directive
@ -517,6 +519,34 @@ The `params` directive can also be passed to other directives, such as `struct`,
as an argument that refers to all parameters in place of explicitly listing all
parameters as variables.
#### `qualified` Directive
This directive can be used to wrap attribute or type parameters such that they
are printed in a fully qualified form, i.e., they include the dialect name and
mnemonic prefix.
For example:
```tablegen
def OuterType : TypeDef<My_Dialect, "MyOuterType"> {
let parameters = (ins MyPairType:$inner);
let mnemonic = "outer";
let assemblyFormat = "`<` pair `:` $inner `>`";
}
def OuterQualifiedType : TypeDef<My_Dialect, "MyOuterQualifiedType"> {
let parameters = (ins MyPairType:$inner);
let mnemonic = "outer_qual";
let assemblyFormat = "`<` pair `:` qualified($inner) `>`";
}
```
In the IR, the types will appear as:
```mlir
!my_dialect.outer<pair : <42, 24>>
!my_dialect.outer_qual<pair : !mydialect.pair<42, 24>>
```
#### `struct` Directive
The `struct` directive accepts a list of variables to capture and will generate

View File

@ -144,6 +144,14 @@ def CompoundNestedOuter : Test_Attr<"CompoundNestedOuter"> {
let assemblyFormat = "`<` `i` $inner `>`";
}
def CompoundNestedOuterQual : Test_Attr<"CompoundNestedOuterQual"> {
let mnemonic = "cmpnd_nested_outer_qual";
// List of type parameters.
let parameters = (ins CompoundNestedInner:$inner);
let assemblyFormat = "`<` `i` qualified($inner) `>`";
}
def TestParamOne : AttrParameter<"int64_t", ""> {}
def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> {

View File

@ -1955,11 +1955,21 @@ def FormatNestedCompoundAttr : TEST_Op<"format_cpmd_nested_attr"> {
let assemblyFormat = "`nested` $nested attr-dict-with-keyword";
}
def FormatQualifiedCompoundAttr : TEST_Op<"format_qual_cpmd_nested_attr"> {
let arguments = (ins CompoundNestedOuter:$nested);
let assemblyFormat = "`nested` qualified($nested) attr-dict-with-keyword";
}
def FormatNestedType : TEST_Op<"format_cpmd_nested_type"> {
let arguments = (ins CompoundNestedOuterType:$nested);
let assemblyFormat = "$nested `nested` type($nested) attr-dict-with-keyword";
}
def FormatQualifiedNestedType : TEST_Op<"format_qual_cpmd_nested_type"> {
let arguments = (ins CompoundNestedOuterType:$nested);
let assemblyFormat = "$nested `nested` qualified(type($nested)) attr-dict-with-keyword";
}
//===----------------------------------------------------------------------===//
// Custom Directives

View File

@ -65,12 +65,20 @@ def CompoundNestedInnerType : Test_Type<"CompoundNestedInner"> {
def CompoundNestedOuterType : Test_Type<"CompoundNestedOuter"> {
let mnemonic = "cmpnd_nested_outer";
// List of type parameters.
let parameters = (ins CompoundNestedInnerType:$inner);
let assemblyFormat = "`<` `i` $inner `>`";
}
def CompoundNestedOuterTypeQual : Test_Type<"CompoundNestedOuterQual"> {
let mnemonic = "cmpnd_nested_outer_qual";
// List of type parameters.
let parameters = (
ins
CompoundNestedInnerType:$inner
);
let assemblyFormat = "`<` `i` $inner `>`";
let assemblyFormat = "`<` `i` qualified($inner) `>`";
}
// An example of how one could implement a standard integer.

View File

@ -301,6 +301,24 @@ module attributes {test.someAttr = #test.cmpnd_nested_outer<i <42 <1, !test.smpl
// CHECK: test.format_cpmd_nested_attr nested <i <42 <1, !test.smpla, [5, 6]>>>
test.format_cpmd_nested_attr nested <i <42 <1, !test.smpla, [5, 6]>>>
//-----
// CHECK: test.format_qual_cpmd_nested_attr nested #test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>
test.format_qual_cpmd_nested_attr nested #test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>
//-----
// Check the `qualified` directive in the declarative assembly format.
// CHECK: @qualifiedCompoundNestedExplicit(%arg0: !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>)
func @qualifiedCompoundNestedExplicit(%arg0: !test.cmpnd_nested_outer<i !test.cmpnd_inner<42 <1, !test.smpla, [5, 6]>>>) -> () {
// Verify that the type prefix is not elided
// CHECK: format_qual_cpmd_nested_type %arg0 nested !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>
test.format_qual_cpmd_nested_type %arg0 nested !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>
return
}
//-----
//===----------------------------------------------------------------------===//
// Format custom directives
//===----------------------------------------------------------------------===//

View File

@ -9,3 +9,7 @@ func private @compoundA() attributes {foo = #test.cmpnd_a<1, !test.smpla, [5, 6]
// CHECK: test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16
%b = test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16
// CHECK-LABEL: @qualifiedAttr()
// CHECK-SAME: #test.cmpnd_nested_outer_qual<i #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>>
func private @qualifiedAttr() attributes {foo = #test.cmpnd_nested_outer_qual<i #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>>}

View File

@ -29,6 +29,10 @@ func @compoundNestedExplicit(%arg0: !test.cmpnd_nested_outer<i !test.cmpnd_inner
return
}
// CHECK-LABEL: @compoundNestedQual
// CHECK-SAME: !test.cmpnd_nested_outer_qual<i !test.cmpnd_inner<42 <1, !test.smpla, [5, 6]>>>
func private @compoundNestedQual(%arg0: !test.cmpnd_nested_outer_qual<i !test.cmpnd_inner<42 <1, !test.smpla, [5, 6]>>>) -> ()
// CHECK: @testInt(%arg0: !test.int<signed, 8>, %arg1: !test.int<unsigned, 2>, %arg2: !test.int<none, 1>)
func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<n, 1>) {
return

View File

@ -89,7 +89,15 @@ public:
/// Get the parameter in the element.
const AttrOrTypeParameter &getParam() const { return param; }
/// Indicate if this variable is printed "qualified" (that is it is
/// prefixed with the `#dialect.mnemonic`).
bool shouldBeQualified() { return shouldBeQualifiedFlag; }
void setShouldBeQualified(bool qualified = true) {
shouldBeQualifiedFlag = qualified;
}
private:
bool shouldBeQualifiedFlag = false;
AttrOrTypeParameter param;
};
@ -166,6 +174,10 @@ static const char *const defaultParameterParser =
static const char *const defaultParameterPrinter =
"$_printer.printStrippedAttrOrType($_self)";
/// Qualified printer for attribute or type parameters: it does not elide
/// dialect and mnemonic.
static const char *const qualifiedParameterPrinter = "$_printer << $_self";
/// Print an error when failing to parse an element.
///
/// $0: The parameter C++ class name.
@ -251,7 +263,7 @@ private:
void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a variable.
void genVariablePrinter(const AttrOrTypeParameter &param, FmtContext &ctx,
MethodBody &os);
MethodBody &os, bool printQualified = false);
/// Generate the printer code for a `params` directive.
void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a `struct` directive.
@ -435,7 +447,8 @@ void AttrOrTypeFormat::genElementPrinter(Element *el, FmtContext &ctx,
if (auto *strct = dyn_cast<StructDirective>(el))
return genStructPrinter(strct, ctx, os);
if (auto *var = dyn_cast<VariableElement>(el))
return genVariablePrinter(var->getParam(), ctx, os);
return genVariablePrinter(var->getParam(), ctx, os,
var->shouldBeQualified());
llvm_unreachable("unknown format element");
}
@ -455,7 +468,8 @@ void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
}
void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter &param,
FmtContext &ctx, MethodBody &os) {
FmtContext &ctx, MethodBody &os,
bool printQualified) {
/// Insert a space before the next parameter, if necessary.
if (shouldEmitSpace || !lastWasPunctuation)
os << tgfmt(" $_printer << ' ';\n", &ctx);
@ -464,7 +478,9 @@ void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter &param,
ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
os << " ";
if (auto printer = param.getPrinter())
if (printQualified)
os << tgfmt(qualifiedParameterPrinter, &ctx) << ";\n";
else if (auto printer = param.getPrinter())
os << tgfmt(*printer, &ctx) << ";\n";
else
os << tgfmt(defaultParameterPrinter, &ctx) << ";\n";
@ -546,6 +562,9 @@ private:
FailureOr<std::unique_ptr<Element>> parseDirective(ParserContext ctx);
/// Parse a `params` directive.
FailureOr<std::unique_ptr<Element>> parseParamsDirective();
/// Parse a `qualified` directive.
FailureOr<std::unique_ptr<Element>>
parseQualifiedDirective(ParserContext ctx);
/// Parse a `struct` directive.
FailureOr<std::unique_ptr<Element>> parseStructDirective();
@ -643,6 +662,8 @@ FailureOr<std::unique_ptr<Element>>
FormatParser::parseDirective(ParserContext ctx) {
switch (curToken.getKind()) {
case FormatToken::kw_qualified:
return parseQualifiedDirective(ctx);
case FormatToken::kw_params:
return parseParamsDirective();
case FormatToken::kw_struct:
@ -656,6 +677,24 @@ FormatParser::parseDirective(ParserContext ctx) {
}
}
FailureOr<std::unique_ptr<Element>>
FormatParser::parseQualifiedDirective(ParserContext ctx) {
consumeToken();
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")))
return failure();
FailureOr<std::unique_ptr<Element>> var = parseElement(ctx);
if (failed(var))
return var;
if (!isa<VariableElement>(*var))
return emitError("`qualified` argument list expected a variable");
cast<VariableElement>(var->get())->setShouldBeQualified();
if (failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return failure();
return var;
}
FailureOr<std::unique_ptr<Element>> FormatParser::parseParamsDirective() {
consumeToken();
/// Collect all of the attribute's or type's parameters.

View File

@ -172,6 +172,7 @@ FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
.Case("struct", FormatToken::kw_struct)
.Case("successors", FormatToken::kw_successors)
.Case("type", FormatToken::kw_type)
.Case("qualified", FormatToken::kw_qualified)
.Default(FormatToken::identifier);
return FormatToken(kind, str);
}

View File

@ -59,6 +59,7 @@ public:
kw_functional_type,
kw_operands,
kw_params,
kw_qualified,
kw_ref,
kw_regions,
kw_results,

View File

@ -117,6 +117,16 @@ struct AttributeVariable
bool isUnitAttr() const {
return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
}
/// Indicate if this attribute is printed "qualified" (that is it is
/// prefixed with the `#dialect.mnemonic`).
bool shouldBeQualified() { return shouldBeQualifiedFlag; }
void setShouldBeQualified(bool qualified = true) {
shouldBeQualifiedFlag = qualified;
}
private:
bool shouldBeQualifiedFlag = false;
};
/// This class represents a variable that refers to an operand argument.
@ -237,9 +247,18 @@ public:
TypeDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
Element *getOperand() const { return operand.get(); }
/// Indicate if this type is printed "qualified" (that is it is
/// prefixed with the `!dialect.mnemonic`).
bool shouldBeQualified() { return shouldBeQualifiedFlag; }
void setShouldBeQualified(bool qualified = true) {
shouldBeQualifiedFlag = qualified;
}
private:
/// The operand that is used to format the directive.
std::unique_ptr<Element> operand;
bool shouldBeQualifiedFlag = false;
};
} // namespace
@ -658,6 +677,10 @@ const char *const typeParserCode = R"(
{1}RawTypes[0] = type;
}
)";
const char *const qualifiedTypeParserCode = R"(
if (parser.parseType({1}RawTypes[0]))
return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for a functional type.
///
@ -1296,7 +1319,8 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
if (var->attr.isOptional()) {
body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
} else {
if (var->attr.getStorageType() == "::mlir::Attribute")
if (attr->shouldBeQualified() ||
var->attr.getStorageType() == "::mlir::Attribute")
body << formatv(genericAttrParserCode, var->name, attrTypeStr);
else
body << formatv(attrParserCode, var->name, attrTypeStr);
@ -1368,14 +1392,16 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
} else if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(optionalTypeParserCode, listName);
} else {
const char *parserCode =
dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
TypeSwitch<Element *>(dir->getOperand())
.Case<OperandVariable, ResultVariable>([&](auto operand) {
body << formatv(typeParserCode,
body << formatv(parserCode,
operand->getVar()->constraint.getCPPClassName(),
listName);
})
.Default([&](auto operand) {
body << formatv(typeParserCode, "::mlir::Type", listName);
body << formatv(parserCode, "::mlir::Type", listName);
});
}
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
@ -2025,7 +2051,8 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
else if (var->attr.isOptional())
body << "_odsPrinter.printAttribute(" << op.getGetterName(var->name)
<< "Attr());\n";
else if (var->attr.getStorageType() == "::mlir::Attribute")
else if (attr->shouldBeQualified() ||
var->attr.getStorageType() == "::mlir::Attribute")
body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name)
<< "Attr());\n";
else
@ -2093,6 +2120,11 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
!var->isOptional()) {
std::string cppClass = var->constraint.getCPPClassName();
if (dir->shouldBeQualified()) {
body << " _odsPrinter << " << op.getGetterName(var->name)
<< "().getType();\n";
return;
}
body << " {\n"
<< " auto type = " << op.getGetterName(var->name)
<< "().getType();\n"
@ -2253,6 +2285,8 @@ private:
ParserContext context);
LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, ParserContext context);
LogicalResult parseQualifiedDirective(std::unique_ptr<Element> &element,
FormatToken tok, ParserContext context);
LogicalResult parseReferenceDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, ParserContext context);
LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element,
@ -2762,6 +2796,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
return parseFunctionalTypeDirective(element, dirTok, context);
case FormatToken::kw_operands:
return parseOperandsDirective(element, dirTok.getLoc(), context);
case FormatToken::kw_qualified:
return parseQualifiedDirective(element, dirTok, context);
case FormatToken::kw_regions:
return parseRegionsDirective(element, dirTok.getLoc(), context);
case FormatToken::kw_results:
@ -3176,6 +3212,27 @@ FormatParser::parseTypeDirective(std::unique_ptr<Element> &element,
return ::mlir::success();
}
LogicalResult
FormatParser::parseQualifiedDirective(std::unique_ptr<Element> &element,
FormatToken tok, ParserContext context) {
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")) ||
failed(parseElement(element, context)) ||
failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return failure();
if (auto *attr = dyn_cast<AttributeVariable>(element.get())) {
attr->setShouldBeQualified();
} else if (auto *type = dyn_cast<TypeDirective>(element.get())) {
type->setShouldBeQualified();
} else {
return emitError(
tok.getLoc(),
"'qualified' directive expects an attribute or a `type` directive");
}
return success();
}
LogicalResult
FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
bool isRefChild) {