[mlir][OpFormatGen] Add support for specifiy "custom" directives.

This revision adds support for custom directives to the declarative assembly format. This allows for users to use C++ for printing and parsing subsections of an otherwise declaratively specified format. The custom directive is structured as follows:

```
custom-directive ::= `custom` `<` UserDirective `>` `(` Params `)`
```

`user-directive` is used as a suffix when this directive is used during printing and parsing. When parsing, `parseUserDirective` will be invoked. When printing, `printUserDirective` will be invoked. The first parameter to these methods must be a reference to either the OpAsmParser, or OpAsmPrinter. The type of rest of the parameters is dependent on the `Params` specified in the assembly format.

Differential Revision: https://reviews.llvm.org/D84719
This commit is contained in:
River Riddle 2020-08-31 12:33:36 -07:00
parent 61e15ecab5
commit 88c6e25e4f
7 changed files with 643 additions and 37 deletions

View File

@ -664,6 +664,12 @@ The available directives are as follows:
- Represents the attribute dictionary of the operation, but prefixes the
dictionary with an `attributes` keyword.
* `custom` < UserDirective > ( Params )
- Represents a custom directive implemented by the user in C++.
- See the [Custom Directives](#custom-directives) section below for more
details.
* `functional-type` ( inputs , results )
- Formats the `inputs` and `results` arguments as a
@ -705,6 +711,75 @@ example above, the variables would be `$callee` and `$args`.
Attribute variables are printed with their respective value type, unless that
value type is buildable. In those cases, the type of the attribute is elided.
#### Custom Directives
The declarative assembly format specification allows for handling a large
majority of the common cases when formatting an operation. For the operations
that require or desire specifying parts of the operation in a form not supported
by the declarative syntax, custom directives may be specified. A custom
directive essentially allows for users to use C++ for printing and parsing
subsections of an otherwise declaratively specified format. Looking at the
specification of a custom directive above:
```
custom-directive ::= `custom` `<` UserDirective `>` `(` Params `)`
```
A custom directive has two main parts: The `UserDirective` and the `Params`. A
custom directive is transformed into a call to a `print*` and a `parse*` method
when generating the C++ code for the format. The `UserDirective` is an
identifier used as a suffix to these two calls, i.e., `custom<MyDirective>(...)`
would result in calls to `parseMyDirective` and `printMyDirective` wihtin the
parser and printer respectively. `Params` may be any combination of variables
(i.e. Attribute, Operand, Successor, etc.) and type directives. The type
directives must refer to a variable, but that variable need not also be a
parameter to the custom directive.
The arguments to the `parse<UserDirective>` method is firstly a reference to the
`OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters
corresponding to the parameters specified in the format. The mapping of
declarative parameter to `parse` method argument is detailed below:
* Attribute Variables
- Single: `<Attribute-Storage-Type>(e.g. Attribute) &`
- Optional: `<Attribute-Storage-Type>(e.g. Attribute) &`
* Operand Variables
- Single: `OpAsmParser::OperandType &`
- Optional: `Optional<OpAsmParser::OperandType> &`
- Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &`
* Successor Variables
- Single: `Block *&`
- Variadic: `SmallVectorImpl<Block *> &`
* Type Directives
- Single: `Type &`
- Optional: `Type &`
- Variadic: `SmallVectorImpl<Type> &`
When a variable is optional, the value should only be specified if the variable
is present. Otherwise, the value should remain `None` or null.
The arguments to the `print<UserDirective>` method is firstly a reference to the
`OpAsmPrinter`(`OpAsmPrinter &`), and secondly a set of output parameters
corresponding to the parameters specified in the format. The mapping of
declarative parameter to `print` method argument is detailed below:
* Attribute Variables
- Single: `<Attribute-Storage-Type>(e.g. Attribute)`
- Optional: `<Attribute-Storage-Type>(e.g. Attribute)`
* Operand Variables
- Single: `Value`
- Optional: `Value`
- Variadic: `OperandRange`
* Successor Variables
- Single: `Block *`
- Variadic: `SuccessorRange`
* Type Directives
- Single: `Type`
- Optional: `Type`
- Variadic: `TypeRange`
When a variable is optional, the provided value may be null.
#### Optional Groups
In certain situations operations may have "optional" information, e.g.
@ -722,8 +797,8 @@ information. An optional group is defined by wrapping a set of elements within
should be printed/parsed.
- An element is marked as the anchor by adding a trailing `^`.
- The first element is *not* required to be the anchor of the group.
* Literals, variables, and type directives are the only valid elements within
the group.
* Literals, variables, custom directives, and type directives are the only
valid elements within the group.
- Any attribute variable may be used, but only optional attributes can be
marked as the anchor.
- Only variadic or optional operand arguments can be used.

View File

@ -202,6 +202,10 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p,
llvm::interleaveComma(types, p);
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const TypeRange &types) {
llvm::interleaveComma(types, p);
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) {
llvm::interleaveComma(types, p);
return p;

View File

@ -267,6 +267,108 @@ void FoldToCallOp::getCanonicalizationPatterns(
results.insert<FoldToCallOpPattern>(context);
}
//===----------------------------------------------------------------------===//
// Test Format* operations
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Parsing
static ParseResult parseCustomDirectiveOperands(
OpAsmParser &parser, OpAsmParser::OperandType &operand,
Optional<OpAsmParser::OperandType> &optOperand,
SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
if (parser.parseOperand(operand))
return failure();
if (succeeded(parser.parseOptionalComma())) {
optOperand.emplace();
if (parser.parseOperand(*optOperand))
return failure();
}
if (parser.parseArrow() || parser.parseLParen() ||
parser.parseOperandList(varOperands) || parser.parseRParen())
return failure();
return success();
}
static ParseResult
parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
Type &optOperandType,
SmallVectorImpl<Type> &varOperandTypes) {
if (parser.parseColon())
return failure();
if (parser.parseType(operandType))
return failure();
if (succeeded(parser.parseOptionalComma())) {
if (parser.parseType(optOperandType))
return failure();
}
if (parser.parseArrow() || parser.parseLParen() ||
parser.parseTypeList(varOperandTypes) || parser.parseRParen())
return failure();
return success();
}
static ParseResult parseCustomDirectiveOperandsAndTypes(
OpAsmParser &parser, OpAsmParser::OperandType &operand,
Optional<OpAsmParser::OperandType> &optOperand,
SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
parseCustomDirectiveResults(parser, operandType, optOperandType,
varOperandTypes))
return failure();
return success();
}
static ParseResult
parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
SmallVectorImpl<Block *> &varSuccessors) {
if (parser.parseSuccessor(successor))
return failure();
if (failed(parser.parseOptionalComma()))
return success();
Block *varSuccessor;
if (parser.parseSuccessor(varSuccessor))
return failure();
varSuccessors.append(2, varSuccessor);
return success();
}
//===----------------------------------------------------------------------===//
// Printing
static void printCustomDirectiveOperands(OpAsmPrinter &printer, Value operand,
Value optOperand,
OperandRange varOperands) {
printer << operand;
if (optOperand)
printer << ", " << optOperand;
printer << " -> (" << varOperands << ")";
}
static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType,
Type optOperandType,
TypeRange varOperandTypes) {
printer << " : " << operandType;
if (optOperandType)
printer << ", " << optOperandType;
printer << " -> (" << varOperandTypes << ")";
}
static void
printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand,
Value optOperand, OperandRange varOperands,
Type operandType, Type optOperandType,
TypeRange varOperandTypes) {
printCustomDirectiveOperands(printer, operand, optOperand, varOperands);
printCustomDirectiveResults(printer, operandType, optOperandType,
varOperandTypes);
}
static void printCustomDirectiveSuccessors(OpAsmPrinter &printer,
Block *successor,
SuccessorRange varSuccessors) {
printer << successor;
if (!varSuccessors.empty())
printer << ", " << varSuccessors.front();
}
//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//

View File

@ -1414,8 +1414,60 @@ def FormatOptionalUnitAttrNoElide
}
//===----------------------------------------------------------------------===//
// AllTypesMatch type inference
// Custom Directives
def FormatCustomDirectiveOperands
: TEST_Op<"format_custom_directive_operands", [AttrSizedOperandSegments]> {
let arguments = (ins I64:$operand, Optional<I64>:$optOperand,
Variadic<I64>:$varOperands);
let assemblyFormat = [{
custom<CustomDirectiveOperands>(
$operand, $optOperand, $varOperands
)
attr-dict
}];
}
def FormatCustomDirectiveOperandsAndTypes
: TEST_Op<"format_custom_directive_operands_and_types",
[AttrSizedOperandSegments]> {
let arguments = (ins AnyType:$operand, Optional<AnyType>:$optOperand,
Variadic<AnyType>:$varOperands);
let assemblyFormat = [{
custom<CustomDirectiveOperandsAndTypes>(
$operand, $optOperand, $varOperands,
type($operand), type($optOperand), type($varOperands)
)
attr-dict
}];
}
def FormatCustomDirectiveResults
: TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> {
let results = (outs AnyType:$result, Optional<AnyType>:$optResult,
Variadic<AnyType>:$varResults);
let assemblyFormat = [{
custom<CustomDirectiveResults>(
type($result), type($optResult), type($varResults)
)
attr-dict
}];
}
def FormatCustomDirectiveSuccessors
: TEST_Op<"format_custom_directive_successors", [Terminator]> {
let successors = (successor AnySuccessor:$successor,
VariadicSuccessor<AnySuccessor>:$successors);
let assemblyFormat = [{
custom<CustomDirectiveSuccessors>(
$successor, $successors
)
attr-dict
}];
}
//===----------------------------------------------------------------------===//
// AllTypesMatch type inference
def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [
AllTypesMatch<["value1", "value2", "result"]>
@ -1435,7 +1487,6 @@ def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
//===----------------------------------------------------------------------===//
// TypesMatchWith type inference
//===----------------------------------------------------------------------===//
def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
TypesMatchWith<"result type matches operand", "value", "result", "$_self">

View File

@ -42,6 +42,49 @@ def DirectiveAttrDictValidB : TestFormat_Op<"attrdict_valid_b", [{
attr-dict-with-keyword
}]>;
//===----------------------------------------------------------------------===//
// custom
// CHECK: error: expected '<' before custom directive name
def DirectiveCustomInvalidA : TestFormat_Op<"custom_invalid_a", [{
custom(
}]>;
// CHECK: error: expected custom directive name identifier
def DirectiveCustomInvalidB : TestFormat_Op<"custom_invalid_b", [{
custom<>
}]>;
// CHECK: error: expected '>' after custom directive name
def DirectiveCustomInvalidC : TestFormat_Op<"custom_invalid_c", [{
custom<MyDirective(
}]>;
// CHECK: error: expected '(' before custom directive parameters
def DirectiveCustomInvalidD : TestFormat_Op<"custom_invalid_d", [{
custom<MyDirective>)
}]>;
// CHECK: error: only variables and types may be used as parameters to a custom directive
def DirectiveCustomInvalidE : TestFormat_Op<"custom_invalid_e", [{
custom<MyDirective>(operands)
}]>;
// CHECK: error: expected ')' after custom directive parameters
def DirectiveCustomInvalidF : TestFormat_Op<"custom_invalid_f", [{
custom<MyDirective>($operand<
}]>, Arguments<(ins I64:$operand)>;
// CHECK: error: type directives within a custom directive may only refer to variables
def DirectiveCustomInvalidH : TestFormat_Op<"custom_invalid_h", [{
custom<MyDirective>(type(operands))
}]>;
// CHECK-NOT: error
def DirectiveCustomValidA : TestFormat_Op<"custom_valid_a", [{
custom<MyDirective>($operand) attr-dict
}]>, Arguments<(ins Optional<I64>:$operand)>;
def DirectiveCustomValidB : TestFormat_Op<"custom_valid_b", [{
custom<MyDirective>($operand, type($operand), type($result)) attr-dict
}]>, Arguments<(ins I64:$operand)>, Results<(outs I64:$result)>;
def DirectiveCustomValidC : TestFormat_Op<"custom_valid_c", [{
custom<MyDirective>($attr) attr-dict
}]>, Arguments<(ins I64Attr:$attr)>;
//===----------------------------------------------------------------------===//
// functional-type
@ -238,6 +281,10 @@ def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{
def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{
($arg^)
}]>, Arguments<(ins Variadic<I64>:$arg)>;
// CHECK: error: only variables can be used to anchor an optional group
def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{
(custom<MyDirective>($arg)^)?
}]>, Arguments<(ins I64:$arg)>;
//===----------------------------------------------------------------------===//
// Variables

View File

@ -122,6 +122,40 @@ test.format_optional_operand_result_b_op( : ) : i64
// CHECK: test.format_optional_operand_result_b_op : i64
test.format_optional_operand_result_b_op : i64
//===----------------------------------------------------------------------===//
// Format custom directives
//===----------------------------------------------------------------------===//
// CHECK: test.format_custom_directive_operands %[[I64]], %[[I64]] -> (%[[I64]])
test.format_custom_directive_operands %i64, %i64 -> (%i64)
// CHECK: test.format_custom_directive_operands %[[I64]] -> (%[[I64]])
test.format_custom_directive_operands %i64 -> (%i64)
// CHECK: test.format_custom_directive_operands_and_types %[[I64]], %[[I64]] -> (%[[I64]]) : i64, i64 -> (i64)
test.format_custom_directive_operands_and_types %i64, %i64 -> (%i64) : i64, i64 -> (i64)
// CHECK: test.format_custom_directive_operands_and_types %[[I64]] -> (%[[I64]]) : i64 -> (i64)
test.format_custom_directive_operands_and_types %i64 -> (%i64) : i64 -> (i64)
// CHECK: test.format_custom_directive_results : i64, i64 -> (i64)
test.format_custom_directive_results : i64, i64 -> (i64)
// CHECK: test.format_custom_directive_results : i64 -> (i64)
test.format_custom_directive_results : i64 -> (i64)
func @foo() {
// CHECK: test.format_custom_directive_successors ^bb1, ^bb2
test.format_custom_directive_successors ^bb1, ^bb2
^bb1:
// CHECK: test.format_custom_directive_successors ^bb2
test.format_custom_directive_successors ^bb2
^bb2:
return
}
//===----------------------------------------------------------------------===//
// Format trait type inference
//===----------------------------------------------------------------------===//

View File

@ -45,6 +45,7 @@ public:
enum class Kind {
/// This element is a directive.
AttrDictDirective,
CustomDirective,
FunctionalTypeDirective,
OperandsDirective,
ResultsDirective,
@ -132,8 +133,7 @@ using SuccessorVariable =
namespace {
/// This class implements single kind directives.
template <Element::Kind type>
class DirectiveElement : public Element {
template <Element::Kind type> class DirectiveElement : public Element {
public:
DirectiveElement() : Element(type){};
static bool classof(const Element *ele) { return ele->getKind() == type; }
@ -164,6 +164,33 @@ private:
bool withKeyword;
};
/// This class represents a custom format directive that is implemented by the
/// user in C++.
class CustomDirective : public Element {
public:
CustomDirective(StringRef name,
std::vector<std::unique_ptr<Element>> &&arguments)
: Element{Kind::CustomDirective}, name(name),
arguments(std::move(arguments)) {}
static bool classof(const Element *element) {
return element->getKind() == Kind::CustomDirective;
}
/// Return the name of this optional element.
StringRef getName() const { return name; }
/// Return the arguments to the custom directive.
auto getArguments() const { return llvm::make_pointee_range(arguments); }
private:
/// The user provided name of the directive.
StringRef name;
/// The arguments to the custom directive.
std::vector<std::unique_ptr<Element>> arguments;
};
/// This class represents the `functional-type` directive. This directive takes
/// two arguments and formats them, respectively, as the inputs and results of a
/// FunctionType.
@ -370,19 +397,16 @@ static bool canFormatEnumAttr(const NamedAttribute *attr) {
/// The code snippet used to generate a parser call for an attribute.
///
/// {0}: The storage type of the attribute.
/// {1}: The name of the attribute.
/// {2}: The type for the attribute.
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
{0} {1}Attr;
if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes))
if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
return failure();
)";
const char *const optionalAttrParserCode = R"(
{0} {1}Attr;
{
::mlir::OptionalParseResult parseResult =
parser.parseOptionalAttribute({1}Attr{2}, "{1}", result.attributes);
parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
if (parseResult.hasValue() && failed(*parseResult))
return failure();
}
@ -408,11 +432,11 @@ const char *const enumAttrParserCode = R"(
return parser.emitError(loc, "invalid ")
<< "{0} attribute specification: " << attrVal;
result.addAttribute("{0}", {3});
{0}Attr = {3};
result.addAttribute("{0}", {0}Attr);
}
)";
const char *const optionalEnumAttrParserCode = R"(
Attribute {0}Attr;
{
::mlir::StringAttr attrVal;
::mlir::NamedAttrList attrStorage;
@ -440,11 +464,13 @@ const char *const optionalEnumAttrParserCode = R"(
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
{0}OperandsLoc = parser.getCurrentLocation();
if (parser.parseOperandList({0}Operands))
return failure();
)";
const char *const optionalOperandParserCode = R"(
{
{0}OperandsLoc = parser.getCurrentLocation();
::mlir::OpAsmParser::OperandType operand;
::mlir::OptionalParseResult parseResult =
parser.parseOptionalOperand(operand);
@ -456,6 +482,7 @@ const char *const optionalOperandParserCode = R"(
}
)";
const char *const operandParserCode = R"(
{0}OperandsLoc = parser.getCurrentLocation();
if (parser.parseOperand({0}RawOperands[0]))
return failure();
)";
@ -500,7 +527,6 @@ const char *const functionalTypeParserCode = R"(
///
/// {0}: The name for the successor list.
const char *successorListParserCode = R"(
::llvm::SmallVector<::mlir::Block *, 2> {0}Successors;
{
::mlir::Block *succ;
auto firstSucc = parser.parseOptionalSuccessor(succ);
@ -523,7 +549,6 @@ const char *successorListParserCode = R"(
///
/// {0}: The name of the successor.
const char *successorParserCode = R"(
::mlir::Block *{0}Successor = nullptr;
if (parser.parseSuccessor({0}Successor))
return failure();
)";
@ -595,8 +620,34 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
/// Generate the storage code required for parsing the given element.
static void genElementParserStorage(Element *element, OpMethodBody &body) {
if (auto *optional = dyn_cast<OptionalElement>(element)) {
for (auto &childElement : optional->getElements())
genElementParserStorage(&childElement, body);
auto elements = optional->getElements();
// If the anchor is a unit attribute, it won't be parsed directly so elide
// it.
auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
Element *elidedAnchorElement = nullptr;
if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr())
elidedAnchorElement = anchor;
for (auto &childElement : elements)
if (&childElement != elidedAnchorElement)
genElementParserStorage(&childElement, body);
} else if (auto *custom = dyn_cast<CustomDirective>(element)) {
for (auto &paramElement : custom->getArguments())
genElementParserStorage(&paramElement, body);
} else if (isa<OperandsDirective>(element)) {
body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
"allOperands;\n";
} else if (isa<SuccessorsDirective>(element)) {
body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
} else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
const NamedAttribute *var = attr->getVar();
body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
var->name);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
StringRef name = operand->getVar()->name;
if (operand->getVar()->isVariableLength()) {
@ -608,10 +659,19 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
<< " ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name
<< "Operands(" << name << "RawOperands);";
}
body << llvm::formatv(
" ::llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation();\n"
" (void){0}OperandsLoc;\n",
name);
body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
" (void){0}OperandsLoc;\n",
name);
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
StringRef name = successor->getVar()->name;
if (successor->getVar()->isVariadic()) {
body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
"{0}Successors;\n",
name);
} else {
body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
}
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind;
StringRef name = getTypeListName(dir->getOperand(), lengthKind);
@ -631,6 +691,106 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
}
}
/// Generate the parser for a parameter to a custom directive.
static void genCustomParameterParser(Element &param, OpMethodBody &body) {
body << ", ";
if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
body << attr->getVar()->name << "Attr";
} else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
StringRef name = operand->getVar()->name;
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv("{0}Operands", name);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv("{0}Operand", name);
else
body << formatv("{0}RawOperands[0]", name);
} else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
StringRef name = successor->getVar()->name;
if (successor->getVar()->isVariadic())
body << llvm::formatv("{0}Successors", name);
else
body << llvm::formatv("{0}Successor", name);
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv("{0}Types", listName);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv("{0}Type", listName);
else
body << formatv("{0}RawTypes[0]", listName);
} else {
llvm_unreachable("unknown custom directive parameter");
}
}
/// Generate the parser for a custom directive.
static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
body << " {\n";
// Preprocess the directive variables.
// * Add a local variable for optional operands and types. This provides a
// better API to the user defined parser methods.
// * Set the location of operand variables.
for (Element &param : dir->getArguments()) {
if (auto *operand = dyn_cast<OperandVariable>(&param)) {
body << " " << operand->getVar()->name
<< "OperandsLoc = parser.getCurrentLocation();\n";
if (operand->getVar()->isOptional()) {
body << llvm::formatv(
" llvm::Optional<::mlir::OpAsmParser::OperandType> "
"{0}Operand;\n",
operand->getVar()->name);
}
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
}
}
body << " if (parse" << dir->getName() << "(parser";
for (Element &param : dir->getArguments())
genCustomParameterParser(param, body);
body << "))\n"
<< " return failure();\n";
// After parsing, add handling for any of the optional constructs.
for (Element &param : dir->getArguments()) {
if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
const NamedAttribute *var = attr->getVar();
if (var->attr.isOptional())
body << llvm::formatv(" if ({0}Attr)\n ", var->name);
body << llvm::formatv(
" result.attributes.addAttribute(\"{0}\", {0}Attr);", var->name);
} else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
const NamedTypeConstraint *var = operand->getVar();
if (!var->isOptional())
continue;
body << llvm::formatv(" if ({0}Operand.hasValue())\n"
" {0}Operands.push_back(*{0}Operand);\n",
var->name);
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(" if ({0}Type)\n"
" {0}Types.push_back({0}Type);\n",
listName);
}
}
}
body << " }\n";
}
/// Generate the parser for a single format element.
static void genElementParser(Element *element, OpMethodBody &body,
FmtContext &attrTypeCtx) {
@ -711,7 +871,7 @@ static void genElementParser(Element *element, OpMethodBody &body,
body << formatv(var->attr.isOptional() ? optionalAttrParserCode
: attrParserCode,
var->attr.getStorageType(), var->name, attrTypeStr);
var->name, attrTypeStr);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
StringRef name = operand->getVar()->name;
@ -732,10 +892,11 @@ static void genElementParser(Element *element, OpMethodBody &body,
<< (attrDict->isWithKeyword() ? "WithKeyword" : "")
<< "(result.attributes))\n"
<< " return failure();\n";
} else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
genCustomDirectiveParser(customDir, body);
} else if (isa<OperandsDirective>(element)) {
body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
<< " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
"allOperands;\n"
<< " if (parser.parseOperandList(allOperands))\n"
<< " return failure();\n";
} else if (isa<SuccessorsDirective>(element)) {
@ -980,6 +1141,20 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
llvm::interleaveComma(op.getOperands(), body, interleaveFn);
body << "}));\n";
}
if (!allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments")) {
body << " result.addAttribute(\"result_segment_sizes\", "
<< "parser.getBuilder().getI32VectorAttr({";
auto interleaveFn = [&](const NamedTypeConstraint &result) {
// If the result is variadic emit the parsed size.
if (result.isVariableLength())
body << "static_cast<int32_t>(" << result.name << "Types.size())";
else
body << "1";
};
llvm::interleaveComma(op.getResults(), body, interleaveFn);
body << "}));\n";
}
}
//===----------------------------------------------------------------------===//
@ -1007,6 +1182,8 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
// Elide the variadic segment size attributes if necessary.
if (!fmt.allOperands && op.getTrait("OpTrait::AttrSizedOperandSegments"))
body << "\"operand_segment_sizes\", ";
if (!fmt.allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments"))
body << "\"result_segment_sizes\", ";
llvm::interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) {
body << "\"" << attr->name << "\"";
});
@ -1038,6 +1215,42 @@ static void genLiteralPrinter(StringRef value, OpMethodBody &body,
lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
}
/// Generate the printer for a literal value. `shouldEmitSpace` is true if a
/// space should be emitted before this element. `lastWasPunctuation` is true if
/// the previous element was a punctuation literal.
static void genCustomDirectivePrinter(CustomDirective *customDir,
OpMethodBody &body) {
body << " print" << customDir->getName() << "(p";
for (Element &param : customDir->getArguments()) {
body << ", ";
if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
body << attr->getVar()->name << "Attr()";
} else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
body << operand->getVar()->name << "()";
} else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
body << successor->getVar()->name << "()";
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
auto *typeOperand = dir->getOperand();
auto *operand = dyn_cast<OperandVariable>(typeOperand);
auto *var = operand ? operand->getVar()
: cast<ResultVariable>(typeOperand)->getVar();
if (var->isVariadic())
body << var->name << "().getTypes()";
else if (var->isOptional())
body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
else
body << var->name << "().getType()";
} else {
llvm_unreachable("unknown custom directive parameter");
}
}
body << ");\n";
}
/// Generate the C++ for an operand to a (*-)type directive.
static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
if (isa<OperandsDirective>(arg))
@ -1145,6 +1358,8 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
body << " ::llvm::interleaveComma(" << var->name << "(), p);\n";
else
body << " p << " << var->name << "();\n";
} else if (auto *dir = dyn_cast<CustomDirective>(element)) {
genCustomDirectivePrinter(dir, body);
} else if (isa<OperandsDirective>(element)) {
body << " p << getOperation()->getOperands();\n";
} else if (isa<SuccessorsDirective>(element)) {
@ -1202,12 +1417,15 @@ public:
caret,
comma,
equal,
less,
greater,
question,
// Keywords.
keyword_start,
kw_attr_dict,
kw_attr_dict_w_keyword,
kw_custom,
kw_functional_type,
kw_operands,
kw_results,
@ -1353,6 +1571,10 @@ Token FormatLexer::lexToken() {
return formToken(Token::comma, tokStart);
case '=':
return formToken(Token::equal, tokStart);
case '<':
return formToken(Token::less, tokStart);
case '>':
return formToken(Token::greater, tokStart);
case '?':
return formToken(Token::question, tokStart);
case '(':
@ -1406,6 +1628,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
llvm::StringSwitch<Token::Kind>(str)
.Case("attr-dict", Token::kw_attr_dict)
.Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
.Case("custom", Token::kw_custom)
.Case("functional-type", Token::kw_functional_type)
.Case("operands", Token::kw_operands)
.Case("results", Token::kw_results)
@ -1421,8 +1644,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
/// Function to find an element within the given range that has the same name as
/// 'name'.
template <typename RangeT>
static auto findArg(RangeT &&range, StringRef name) {
template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
return it != range.end() ? &*it : nullptr;
}
@ -1513,6 +1735,10 @@ private:
LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel,
bool withKeyword);
LogicalResult parseCustomDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel);
LogicalResult parseCustomDirectiveParameter(
std::vector<std::unique_ptr<Element>> &parameters);
LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
Token tok, bool isTopLevel);
LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
@ -1930,6 +2156,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
case Token::kw_attr_dict_w_keyword:
return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
/*withKeyword=*/true);
case Token::kw_custom:
return parseCustomDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_functional_type:
return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
case Token::kw_operands:
@ -2054,15 +2282,15 @@ LogicalResult FormatParser::parseOptionalChildElement(
seenVariables.insert(ele->getVar());
return success();
})
// Literals and type directives may be used, but they can't anchor the
// group.
.Case<LiteralElement, TypeDirective, FunctionalTypeDirective>(
[&](Element *) {
if (isAnchor)
return emitError(childLoc, "only variables can be used to anchor "
"an optional group");
return success();
})
// Literals, custom directives, and type directives may be used,
// but they can't anchor the group.
.Case<LiteralElement, CustomDirective, TypeDirective,
FunctionalTypeDirective>([&](Element *) {
if (isAnchor)
return emitError(childLoc, "only variables can be used to anchor "
"an optional group");
return success();
})
.Default([&](Element *) {
return emitError(childLoc, "only literals, types, and variables can be "
"used within an optional group");
@ -2084,6 +2312,71 @@ FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
return success();
}
LogicalResult
FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel) {
llvm::SMLoc curLoc = curToken.getLoc();
// Parse the custom directive name.
if (failed(
parseToken(Token::less, "expected '<' before custom directive name")))
return failure();
Token nameTok = curToken;
if (failed(parseToken(Token::identifier,
"expected custom directive name identifier")) ||
failed(parseToken(Token::greater,
"expected '>' after custom directive name")) ||
failed(parseToken(Token::l_paren,
"expected '(' before custom directive parameters")))
return failure();
// Parse the child elements for this optional group.=
std::vector<std::unique_ptr<Element>> elements;
do {
if (failed(parseCustomDirectiveParameter(elements)))
return failure();
if (curToken.getKind() != Token::comma)
break;
consumeToken();
} while (true);
if (failed(parseToken(Token::r_paren,
"expected ')' after custom directive parameters")))
return failure();
// After parsing all of the elements, ensure that all type directives refer
// only to variables.
for (auto &ele : elements) {
if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
return emitError(curLoc, "type directives within a custom directive "
"may only refer to variables");
}
}
}
element = std::make_unique<CustomDirective>(nameTok.getSpelling(),
std::move(elements));
return success();
}
LogicalResult FormatParser::parseCustomDirectiveParameter(
std::vector<std::unique_ptr<Element>> &parameters) {
llvm::SMLoc childLoc = curToken.getLoc();
parameters.push_back({});
if (failed(parseElement(parameters.back(), /*isTopLevel=*/true)))
return failure();
// Verify that the element can be placed within a custom directive.
if (!isa<TypeDirective, AttributeVariable, OperandVariable,
SuccessorVariable>(parameters.back().get())) {
return emitError(childLoc, "only variables and types may be used as "
"parameters to a custom directive");
}
return success();
}
LogicalResult
FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
Token tok, bool isTopLevel) {