forked from OSchip/llvm-project
[mlir][OpFormatGen] Add support for specifiy "custom" directives.
This revision adds support for custom directives to the declarative assembly format. This allows for users to use C++ for printing and parsing subsections of an otherwise declaratively specified format. The custom directive is structured as follows: ``` custom-directive ::= `custom` `<` UserDirective `>` `(` Params `)` ``` `user-directive` is used as a suffix when this directive is used during printing and parsing. When parsing, `parseUserDirective` will be invoked. When printing, `printUserDirective` will be invoked. The first parameter to these methods must be a reference to either the OpAsmParser, or OpAsmPrinter. The type of rest of the parameters is dependent on the `Params` specified in the assembly format. Differential Revision: https://reviews.llvm.org/D84719
This commit is contained in:
parent
61e15ecab5
commit
88c6e25e4f
|
@ -664,6 +664,12 @@ The available directives are as follows:
|
|||
- Represents the attribute dictionary of the operation, but prefixes the
|
||||
dictionary with an `attributes` keyword.
|
||||
|
||||
* `custom` < UserDirective > ( Params )
|
||||
|
||||
- Represents a custom directive implemented by the user in C++.
|
||||
- See the [Custom Directives](#custom-directives) section below for more
|
||||
details.
|
||||
|
||||
* `functional-type` ( inputs , results )
|
||||
|
||||
- Formats the `inputs` and `results` arguments as a
|
||||
|
@ -705,6 +711,75 @@ example above, the variables would be `$callee` and `$args`.
|
|||
Attribute variables are printed with their respective value type, unless that
|
||||
value type is buildable. In those cases, the type of the attribute is elided.
|
||||
|
||||
#### Custom Directives
|
||||
|
||||
The declarative assembly format specification allows for handling a large
|
||||
majority of the common cases when formatting an operation. For the operations
|
||||
that require or desire specifying parts of the operation in a form not supported
|
||||
by the declarative syntax, custom directives may be specified. A custom
|
||||
directive essentially allows for users to use C++ for printing and parsing
|
||||
subsections of an otherwise declaratively specified format. Looking at the
|
||||
specification of a custom directive above:
|
||||
|
||||
```
|
||||
custom-directive ::= `custom` `<` UserDirective `>` `(` Params `)`
|
||||
```
|
||||
|
||||
A custom directive has two main parts: The `UserDirective` and the `Params`. A
|
||||
custom directive is transformed into a call to a `print*` and a `parse*` method
|
||||
when generating the C++ code for the format. The `UserDirective` is an
|
||||
identifier used as a suffix to these two calls, i.e., `custom<MyDirective>(...)`
|
||||
would result in calls to `parseMyDirective` and `printMyDirective` wihtin the
|
||||
parser and printer respectively. `Params` may be any combination of variables
|
||||
(i.e. Attribute, Operand, Successor, etc.) and type directives. The type
|
||||
directives must refer to a variable, but that variable need not also be a
|
||||
parameter to the custom directive.
|
||||
|
||||
The arguments to the `parse<UserDirective>` method is firstly a reference to the
|
||||
`OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters
|
||||
corresponding to the parameters specified in the format. The mapping of
|
||||
declarative parameter to `parse` method argument is detailed below:
|
||||
|
||||
* Attribute Variables
|
||||
- Single: `<Attribute-Storage-Type>(e.g. Attribute) &`
|
||||
- Optional: `<Attribute-Storage-Type>(e.g. Attribute) &`
|
||||
* Operand Variables
|
||||
- Single: `OpAsmParser::OperandType &`
|
||||
- Optional: `Optional<OpAsmParser::OperandType> &`
|
||||
- Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &`
|
||||
* Successor Variables
|
||||
- Single: `Block *&`
|
||||
- Variadic: `SmallVectorImpl<Block *> &`
|
||||
* Type Directives
|
||||
- Single: `Type &`
|
||||
- Optional: `Type &`
|
||||
- Variadic: `SmallVectorImpl<Type> &`
|
||||
|
||||
When a variable is optional, the value should only be specified if the variable
|
||||
is present. Otherwise, the value should remain `None` or null.
|
||||
|
||||
The arguments to the `print<UserDirective>` method is firstly a reference to the
|
||||
`OpAsmPrinter`(`OpAsmPrinter &`), and secondly a set of output parameters
|
||||
corresponding to the parameters specified in the format. The mapping of
|
||||
declarative parameter to `print` method argument is detailed below:
|
||||
|
||||
* Attribute Variables
|
||||
- Single: `<Attribute-Storage-Type>(e.g. Attribute)`
|
||||
- Optional: `<Attribute-Storage-Type>(e.g. Attribute)`
|
||||
* Operand Variables
|
||||
- Single: `Value`
|
||||
- Optional: `Value`
|
||||
- Variadic: `OperandRange`
|
||||
* Successor Variables
|
||||
- Single: `Block *`
|
||||
- Variadic: `SuccessorRange`
|
||||
* Type Directives
|
||||
- Single: `Type`
|
||||
- Optional: `Type`
|
||||
- Variadic: `TypeRange`
|
||||
|
||||
When a variable is optional, the provided value may be null.
|
||||
|
||||
#### Optional Groups
|
||||
|
||||
In certain situations operations may have "optional" information, e.g.
|
||||
|
@ -722,8 +797,8 @@ information. An optional group is defined by wrapping a set of elements within
|
|||
should be printed/parsed.
|
||||
- An element is marked as the anchor by adding a trailing `^`.
|
||||
- The first element is *not* required to be the anchor of the group.
|
||||
* Literals, variables, and type directives are the only valid elements within
|
||||
the group.
|
||||
* Literals, variables, custom directives, and type directives are the only
|
||||
valid elements within the group.
|
||||
- Any attribute variable may be used, but only optional attributes can be
|
||||
marked as the anchor.
|
||||
- Only variadic or optional operand arguments can be used.
|
||||
|
|
|
@ -202,6 +202,10 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p,
|
|||
llvm::interleaveComma(types, p);
|
||||
return p;
|
||||
}
|
||||
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const TypeRange &types) {
|
||||
llvm::interleaveComma(types, p);
|
||||
return p;
|
||||
}
|
||||
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) {
|
||||
llvm::interleaveComma(types, p);
|
||||
return p;
|
||||
|
|
|
@ -267,6 +267,108 @@ void FoldToCallOp::getCanonicalizationPatterns(
|
|||
results.insert<FoldToCallOpPattern>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Format* operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Parsing
|
||||
|
||||
static ParseResult parseCustomDirectiveOperands(
|
||||
OpAsmParser &parser, OpAsmParser::OperandType &operand,
|
||||
Optional<OpAsmParser::OperandType> &optOperand,
|
||||
SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
|
||||
if (parser.parseOperand(operand))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalComma())) {
|
||||
optOperand.emplace();
|
||||
if (parser.parseOperand(*optOperand))
|
||||
return failure();
|
||||
}
|
||||
if (parser.parseArrow() || parser.parseLParen() ||
|
||||
parser.parseOperandList(varOperands) || parser.parseRParen())
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
static ParseResult
|
||||
parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
|
||||
Type &optOperandType,
|
||||
SmallVectorImpl<Type> &varOperandTypes) {
|
||||
if (parser.parseColon())
|
||||
return failure();
|
||||
|
||||
if (parser.parseType(operandType))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseType(optOperandType))
|
||||
return failure();
|
||||
}
|
||||
if (parser.parseArrow() || parser.parseLParen() ||
|
||||
parser.parseTypeList(varOperandTypes) || parser.parseRParen())
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
static ParseResult parseCustomDirectiveOperandsAndTypes(
|
||||
OpAsmParser &parser, OpAsmParser::OperandType &operand,
|
||||
Optional<OpAsmParser::OperandType> &optOperand,
|
||||
SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
|
||||
Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
|
||||
if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
|
||||
parseCustomDirectiveResults(parser, operandType, optOperandType,
|
||||
varOperandTypes))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
static ParseResult
|
||||
parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
|
||||
SmallVectorImpl<Block *> &varSuccessors) {
|
||||
if (parser.parseSuccessor(successor))
|
||||
return failure();
|
||||
if (failed(parser.parseOptionalComma()))
|
||||
return success();
|
||||
Block *varSuccessor;
|
||||
if (parser.parseSuccessor(varSuccessor))
|
||||
return failure();
|
||||
varSuccessors.append(2, varSuccessor);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing
|
||||
|
||||
static void printCustomDirectiveOperands(OpAsmPrinter &printer, Value operand,
|
||||
Value optOperand,
|
||||
OperandRange varOperands) {
|
||||
printer << operand;
|
||||
if (optOperand)
|
||||
printer << ", " << optOperand;
|
||||
printer << " -> (" << varOperands << ")";
|
||||
}
|
||||
static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType,
|
||||
Type optOperandType,
|
||||
TypeRange varOperandTypes) {
|
||||
printer << " : " << operandType;
|
||||
if (optOperandType)
|
||||
printer << ", " << optOperandType;
|
||||
printer << " -> (" << varOperandTypes << ")";
|
||||
}
|
||||
static void
|
||||
printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand,
|
||||
Value optOperand, OperandRange varOperands,
|
||||
Type operandType, Type optOperandType,
|
||||
TypeRange varOperandTypes) {
|
||||
printCustomDirectiveOperands(printer, operand, optOperand, varOperands);
|
||||
printCustomDirectiveResults(printer, operandType, optOperandType,
|
||||
varOperandTypes);
|
||||
}
|
||||
static void printCustomDirectiveSuccessors(OpAsmPrinter &printer,
|
||||
Block *successor,
|
||||
SuccessorRange varSuccessors) {
|
||||
printer << successor;
|
||||
if (!varSuccessors.empty())
|
||||
printer << ", " << varSuccessors.front();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test IsolatedRegionOp - parse passthrough region arguments.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1414,8 +1414,60 @@ def FormatOptionalUnitAttrNoElide
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AllTypesMatch type inference
|
||||
// Custom Directives
|
||||
|
||||
def FormatCustomDirectiveOperands
|
||||
: TEST_Op<"format_custom_directive_operands", [AttrSizedOperandSegments]> {
|
||||
let arguments = (ins I64:$operand, Optional<I64>:$optOperand,
|
||||
Variadic<I64>:$varOperands);
|
||||
let assemblyFormat = [{
|
||||
custom<CustomDirectiveOperands>(
|
||||
$operand, $optOperand, $varOperands
|
||||
)
|
||||
attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
def FormatCustomDirectiveOperandsAndTypes
|
||||
: TEST_Op<"format_custom_directive_operands_and_types",
|
||||
[AttrSizedOperandSegments]> {
|
||||
let arguments = (ins AnyType:$operand, Optional<AnyType>:$optOperand,
|
||||
Variadic<AnyType>:$varOperands);
|
||||
let assemblyFormat = [{
|
||||
custom<CustomDirectiveOperandsAndTypes>(
|
||||
$operand, $optOperand, $varOperands,
|
||||
type($operand), type($optOperand), type($varOperands)
|
||||
)
|
||||
attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
def FormatCustomDirectiveResults
|
||||
: TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> {
|
||||
let results = (outs AnyType:$result, Optional<AnyType>:$optResult,
|
||||
Variadic<AnyType>:$varResults);
|
||||
let assemblyFormat = [{
|
||||
custom<CustomDirectiveResults>(
|
||||
type($result), type($optResult), type($varResults)
|
||||
)
|
||||
attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
def FormatCustomDirectiveSuccessors
|
||||
: TEST_Op<"format_custom_directive_successors", [Terminator]> {
|
||||
let successors = (successor AnySuccessor:$successor,
|
||||
VariadicSuccessor<AnySuccessor>:$successors);
|
||||
let assemblyFormat = [{
|
||||
custom<CustomDirectiveSuccessors>(
|
||||
$successor, $successors
|
||||
)
|
||||
attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AllTypesMatch type inference
|
||||
|
||||
def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [
|
||||
AllTypesMatch<["value1", "value2", "result"]>
|
||||
|
@ -1435,7 +1487,6 @@ def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypesMatchWith type inference
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
|
||||
TypesMatchWith<"result type matches operand", "value", "result", "$_self">
|
||||
|
|
|
@ -42,6 +42,49 @@ def DirectiveAttrDictValidB : TestFormat_Op<"attrdict_valid_b", [{
|
|||
attr-dict-with-keyword
|
||||
}]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// custom
|
||||
|
||||
// CHECK: error: expected '<' before custom directive name
|
||||
def DirectiveCustomInvalidA : TestFormat_Op<"custom_invalid_a", [{
|
||||
custom(
|
||||
}]>;
|
||||
// CHECK: error: expected custom directive name identifier
|
||||
def DirectiveCustomInvalidB : TestFormat_Op<"custom_invalid_b", [{
|
||||
custom<>
|
||||
}]>;
|
||||
// CHECK: error: expected '>' after custom directive name
|
||||
def DirectiveCustomInvalidC : TestFormat_Op<"custom_invalid_c", [{
|
||||
custom<MyDirective(
|
||||
}]>;
|
||||
// CHECK: error: expected '(' before custom directive parameters
|
||||
def DirectiveCustomInvalidD : TestFormat_Op<"custom_invalid_d", [{
|
||||
custom<MyDirective>)
|
||||
}]>;
|
||||
// CHECK: error: only variables and types may be used as parameters to a custom directive
|
||||
def DirectiveCustomInvalidE : TestFormat_Op<"custom_invalid_e", [{
|
||||
custom<MyDirective>(operands)
|
||||
}]>;
|
||||
// CHECK: error: expected ')' after custom directive parameters
|
||||
def DirectiveCustomInvalidF : TestFormat_Op<"custom_invalid_f", [{
|
||||
custom<MyDirective>($operand<
|
||||
}]>, Arguments<(ins I64:$operand)>;
|
||||
// CHECK: error: type directives within a custom directive may only refer to variables
|
||||
def DirectiveCustomInvalidH : TestFormat_Op<"custom_invalid_h", [{
|
||||
custom<MyDirective>(type(operands))
|
||||
}]>;
|
||||
|
||||
// CHECK-NOT: error
|
||||
def DirectiveCustomValidA : TestFormat_Op<"custom_valid_a", [{
|
||||
custom<MyDirective>($operand) attr-dict
|
||||
}]>, Arguments<(ins Optional<I64>:$operand)>;
|
||||
def DirectiveCustomValidB : TestFormat_Op<"custom_valid_b", [{
|
||||
custom<MyDirective>($operand, type($operand), type($result)) attr-dict
|
||||
}]>, Arguments<(ins I64:$operand)>, Results<(outs I64:$result)>;
|
||||
def DirectiveCustomValidC : TestFormat_Op<"custom_valid_c", [{
|
||||
custom<MyDirective>($attr) attr-dict
|
||||
}]>, Arguments<(ins I64Attr:$attr)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// functional-type
|
||||
|
||||
|
@ -238,6 +281,10 @@ def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{
|
|||
def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{
|
||||
($arg^)
|
||||
}]>, Arguments<(ins Variadic<I64>:$arg)>;
|
||||
// CHECK: error: only variables can be used to anchor an optional group
|
||||
def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{
|
||||
(custom<MyDirective>($arg)^)?
|
||||
}]>, Arguments<(ins I64:$arg)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Variables
|
||||
|
|
|
@ -122,6 +122,40 @@ test.format_optional_operand_result_b_op( : ) : i64
|
|||
// CHECK: test.format_optional_operand_result_b_op : i64
|
||||
test.format_optional_operand_result_b_op : i64
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Format custom directives
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: test.format_custom_directive_operands %[[I64]], %[[I64]] -> (%[[I64]])
|
||||
test.format_custom_directive_operands %i64, %i64 -> (%i64)
|
||||
|
||||
// CHECK: test.format_custom_directive_operands %[[I64]] -> (%[[I64]])
|
||||
test.format_custom_directive_operands %i64 -> (%i64)
|
||||
|
||||
// CHECK: test.format_custom_directive_operands_and_types %[[I64]], %[[I64]] -> (%[[I64]]) : i64, i64 -> (i64)
|
||||
test.format_custom_directive_operands_and_types %i64, %i64 -> (%i64) : i64, i64 -> (i64)
|
||||
|
||||
// CHECK: test.format_custom_directive_operands_and_types %[[I64]] -> (%[[I64]]) : i64 -> (i64)
|
||||
test.format_custom_directive_operands_and_types %i64 -> (%i64) : i64 -> (i64)
|
||||
|
||||
// CHECK: test.format_custom_directive_results : i64, i64 -> (i64)
|
||||
test.format_custom_directive_results : i64, i64 -> (i64)
|
||||
|
||||
// CHECK: test.format_custom_directive_results : i64 -> (i64)
|
||||
test.format_custom_directive_results : i64 -> (i64)
|
||||
|
||||
func @foo() {
|
||||
// CHECK: test.format_custom_directive_successors ^bb1, ^bb2
|
||||
test.format_custom_directive_successors ^bb1, ^bb2
|
||||
|
||||
^bb1:
|
||||
// CHECK: test.format_custom_directive_successors ^bb2
|
||||
test.format_custom_directive_successors ^bb2
|
||||
|
||||
^bb2:
|
||||
return
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Format trait type inference
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -45,6 +45,7 @@ public:
|
|||
enum class Kind {
|
||||
/// This element is a directive.
|
||||
AttrDictDirective,
|
||||
CustomDirective,
|
||||
FunctionalTypeDirective,
|
||||
OperandsDirective,
|
||||
ResultsDirective,
|
||||
|
@ -132,8 +133,7 @@ using SuccessorVariable =
|
|||
|
||||
namespace {
|
||||
/// This class implements single kind directives.
|
||||
template <Element::Kind type>
|
||||
class DirectiveElement : public Element {
|
||||
template <Element::Kind type> class DirectiveElement : public Element {
|
||||
public:
|
||||
DirectiveElement() : Element(type){};
|
||||
static bool classof(const Element *ele) { return ele->getKind() == type; }
|
||||
|
@ -164,6 +164,33 @@ private:
|
|||
bool withKeyword;
|
||||
};
|
||||
|
||||
/// This class represents a custom format directive that is implemented by the
|
||||
/// user in C++.
|
||||
class CustomDirective : public Element {
|
||||
public:
|
||||
CustomDirective(StringRef name,
|
||||
std::vector<std::unique_ptr<Element>> &&arguments)
|
||||
: Element{Kind::CustomDirective}, name(name),
|
||||
arguments(std::move(arguments)) {}
|
||||
|
||||
static bool classof(const Element *element) {
|
||||
return element->getKind() == Kind::CustomDirective;
|
||||
}
|
||||
|
||||
/// Return the name of this optional element.
|
||||
StringRef getName() const { return name; }
|
||||
|
||||
/// Return the arguments to the custom directive.
|
||||
auto getArguments() const { return llvm::make_pointee_range(arguments); }
|
||||
|
||||
private:
|
||||
/// The user provided name of the directive.
|
||||
StringRef name;
|
||||
|
||||
/// The arguments to the custom directive.
|
||||
std::vector<std::unique_ptr<Element>> arguments;
|
||||
};
|
||||
|
||||
/// This class represents the `functional-type` directive. This directive takes
|
||||
/// two arguments and formats them, respectively, as the inputs and results of a
|
||||
/// FunctionType.
|
||||
|
@ -370,19 +397,16 @@ static bool canFormatEnumAttr(const NamedAttribute *attr) {
|
|||
|
||||
/// The code snippet used to generate a parser call for an attribute.
|
||||
///
|
||||
/// {0}: The storage type of the attribute.
|
||||
/// {1}: The name of the attribute.
|
||||
/// {2}: The type for the attribute.
|
||||
/// {0}: The name of the attribute.
|
||||
/// {1}: The type for the attribute.
|
||||
const char *const attrParserCode = R"(
|
||||
{0} {1}Attr;
|
||||
if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes))
|
||||
if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
|
||||
return failure();
|
||||
)";
|
||||
const char *const optionalAttrParserCode = R"(
|
||||
{0} {1}Attr;
|
||||
{
|
||||
::mlir::OptionalParseResult parseResult =
|
||||
parser.parseOptionalAttribute({1}Attr{2}, "{1}", result.attributes);
|
||||
parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
|
||||
if (parseResult.hasValue() && failed(*parseResult))
|
||||
return failure();
|
||||
}
|
||||
|
@ -408,11 +432,11 @@ const char *const enumAttrParserCode = R"(
|
|||
return parser.emitError(loc, "invalid ")
|
||||
<< "{0} attribute specification: " << attrVal;
|
||||
|
||||
result.addAttribute("{0}", {3});
|
||||
{0}Attr = {3};
|
||||
result.addAttribute("{0}", {0}Attr);
|
||||
}
|
||||
)";
|
||||
const char *const optionalEnumAttrParserCode = R"(
|
||||
Attribute {0}Attr;
|
||||
{
|
||||
::mlir::StringAttr attrVal;
|
||||
::mlir::NamedAttrList attrStorage;
|
||||
|
@ -440,11 +464,13 @@ const char *const optionalEnumAttrParserCode = R"(
|
|||
///
|
||||
/// {0}: The name of the operand.
|
||||
const char *const variadicOperandParserCode = R"(
|
||||
{0}OperandsLoc = parser.getCurrentLocation();
|
||||
if (parser.parseOperandList({0}Operands))
|
||||
return failure();
|
||||
)";
|
||||
const char *const optionalOperandParserCode = R"(
|
||||
{
|
||||
{0}OperandsLoc = parser.getCurrentLocation();
|
||||
::mlir::OpAsmParser::OperandType operand;
|
||||
::mlir::OptionalParseResult parseResult =
|
||||
parser.parseOptionalOperand(operand);
|
||||
|
@ -456,6 +482,7 @@ const char *const optionalOperandParserCode = R"(
|
|||
}
|
||||
)";
|
||||
const char *const operandParserCode = R"(
|
||||
{0}OperandsLoc = parser.getCurrentLocation();
|
||||
if (parser.parseOperand({0}RawOperands[0]))
|
||||
return failure();
|
||||
)";
|
||||
|
@ -500,7 +527,6 @@ const char *const functionalTypeParserCode = R"(
|
|||
///
|
||||
/// {0}: The name for the successor list.
|
||||
const char *successorListParserCode = R"(
|
||||
::llvm::SmallVector<::mlir::Block *, 2> {0}Successors;
|
||||
{
|
||||
::mlir::Block *succ;
|
||||
auto firstSucc = parser.parseOptionalSuccessor(succ);
|
||||
|
@ -523,7 +549,6 @@ const char *successorListParserCode = R"(
|
|||
///
|
||||
/// {0}: The name of the successor.
|
||||
const char *successorParserCode = R"(
|
||||
::mlir::Block *{0}Successor = nullptr;
|
||||
if (parser.parseSuccessor({0}Successor))
|
||||
return failure();
|
||||
)";
|
||||
|
@ -595,8 +620,34 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
|
|||
/// Generate the storage code required for parsing the given element.
|
||||
static void genElementParserStorage(Element *element, OpMethodBody &body) {
|
||||
if (auto *optional = dyn_cast<OptionalElement>(element)) {
|
||||
for (auto &childElement : optional->getElements())
|
||||
genElementParserStorage(&childElement, body);
|
||||
auto elements = optional->getElements();
|
||||
|
||||
// If the anchor is a unit attribute, it won't be parsed directly so elide
|
||||
// it.
|
||||
auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
|
||||
Element *elidedAnchorElement = nullptr;
|
||||
if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr())
|
||||
elidedAnchorElement = anchor;
|
||||
for (auto &childElement : elements)
|
||||
if (&childElement != elidedAnchorElement)
|
||||
genElementParserStorage(&childElement, body);
|
||||
|
||||
} else if (auto *custom = dyn_cast<CustomDirective>(element)) {
|
||||
for (auto ¶mElement : custom->getArguments())
|
||||
genElementParserStorage(¶mElement, body);
|
||||
|
||||
} else if (isa<OperandsDirective>(element)) {
|
||||
body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
|
||||
"allOperands;\n";
|
||||
|
||||
} else if (isa<SuccessorsDirective>(element)) {
|
||||
body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
|
||||
|
||||
} else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
|
||||
const NamedAttribute *var = attr->getVar();
|
||||
body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
|
||||
var->name);
|
||||
|
||||
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
|
||||
StringRef name = operand->getVar()->name;
|
||||
if (operand->getVar()->isVariableLength()) {
|
||||
|
@ -608,10 +659,19 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
|
|||
<< " ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name
|
||||
<< "Operands(" << name << "RawOperands);";
|
||||
}
|
||||
body << llvm::formatv(
|
||||
" ::llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation();\n"
|
||||
" (void){0}OperandsLoc;\n",
|
||||
name);
|
||||
body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
|
||||
" (void){0}OperandsLoc;\n",
|
||||
name);
|
||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
|
||||
StringRef name = successor->getVar()->name;
|
||||
if (successor->getVar()->isVariadic()) {
|
||||
body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
|
||||
"{0}Successors;\n",
|
||||
name);
|
||||
} else {
|
||||
body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
|
||||
}
|
||||
|
||||
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
|
||||
ArgumentLengthKind lengthKind;
|
||||
StringRef name = getTypeListName(dir->getOperand(), lengthKind);
|
||||
|
@ -631,6 +691,106 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Generate the parser for a parameter to a custom directive.
|
||||
static void genCustomParameterParser(Element ¶m, OpMethodBody &body) {
|
||||
body << ", ";
|
||||
if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
|
||||
body << attr->getVar()->name << "Attr";
|
||||
|
||||
} else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
|
||||
StringRef name = operand->getVar()->name;
|
||||
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
|
||||
if (lengthKind == ArgumentLengthKind::Variadic)
|
||||
body << llvm::formatv("{0}Operands", name);
|
||||
else if (lengthKind == ArgumentLengthKind::Optional)
|
||||
body << llvm::formatv("{0}Operand", name);
|
||||
else
|
||||
body << formatv("{0}RawOperands[0]", name);
|
||||
|
||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(¶m)) {
|
||||
StringRef name = successor->getVar()->name;
|
||||
if (successor->getVar()->isVariadic())
|
||||
body << llvm::formatv("{0}Successors", name);
|
||||
else
|
||||
body << llvm::formatv("{0}Successor", name);
|
||||
|
||||
} else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
|
||||
ArgumentLengthKind lengthKind;
|
||||
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
|
||||
if (lengthKind == ArgumentLengthKind::Variadic)
|
||||
body << llvm::formatv("{0}Types", listName);
|
||||
else if (lengthKind == ArgumentLengthKind::Optional)
|
||||
body << llvm::formatv("{0}Type", listName);
|
||||
else
|
||||
body << formatv("{0}RawTypes[0]", listName);
|
||||
} else {
|
||||
llvm_unreachable("unknown custom directive parameter");
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate the parser for a custom directive.
|
||||
static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
|
||||
body << " {\n";
|
||||
|
||||
// Preprocess the directive variables.
|
||||
// * Add a local variable for optional operands and types. This provides a
|
||||
// better API to the user defined parser methods.
|
||||
// * Set the location of operand variables.
|
||||
for (Element ¶m : dir->getArguments()) {
|
||||
if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
|
||||
body << " " << operand->getVar()->name
|
||||
<< "OperandsLoc = parser.getCurrentLocation();\n";
|
||||
if (operand->getVar()->isOptional()) {
|
||||
body << llvm::formatv(
|
||||
" llvm::Optional<::mlir::OpAsmParser::OperandType> "
|
||||
"{0}Operand;\n",
|
||||
operand->getVar()->name);
|
||||
}
|
||||
} else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
|
||||
ArgumentLengthKind lengthKind;
|
||||
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
|
||||
if (lengthKind == ArgumentLengthKind::Optional)
|
||||
body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
|
||||
}
|
||||
}
|
||||
|
||||
body << " if (parse" << dir->getName() << "(parser";
|
||||
for (Element ¶m : dir->getArguments())
|
||||
genCustomParameterParser(param, body);
|
||||
|
||||
body << "))\n"
|
||||
<< " return failure();\n";
|
||||
|
||||
// After parsing, add handling for any of the optional constructs.
|
||||
for (Element ¶m : dir->getArguments()) {
|
||||
if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
|
||||
const NamedAttribute *var = attr->getVar();
|
||||
if (var->attr.isOptional())
|
||||
body << llvm::formatv(" if ({0}Attr)\n ", var->name);
|
||||
|
||||
body << llvm::formatv(
|
||||
" result.attributes.addAttribute(\"{0}\", {0}Attr);", var->name);
|
||||
} else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
|
||||
const NamedTypeConstraint *var = operand->getVar();
|
||||
if (!var->isOptional())
|
||||
continue;
|
||||
body << llvm::formatv(" if ({0}Operand.hasValue())\n"
|
||||
" {0}Operands.push_back(*{0}Operand);\n",
|
||||
var->name);
|
||||
} else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
|
||||
ArgumentLengthKind lengthKind;
|
||||
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
|
||||
if (lengthKind == ArgumentLengthKind::Optional) {
|
||||
body << llvm::formatv(" if ({0}Type)\n"
|
||||
" {0}Types.push_back({0}Type);\n",
|
||||
listName);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body << " }\n";
|
||||
}
|
||||
|
||||
/// Generate the parser for a single format element.
|
||||
static void genElementParser(Element *element, OpMethodBody &body,
|
||||
FmtContext &attrTypeCtx) {
|
||||
|
@ -711,7 +871,7 @@ static void genElementParser(Element *element, OpMethodBody &body,
|
|||
|
||||
body << formatv(var->attr.isOptional() ? optionalAttrParserCode
|
||||
: attrParserCode,
|
||||
var->attr.getStorageType(), var->name, attrTypeStr);
|
||||
var->name, attrTypeStr);
|
||||
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
|
||||
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
|
||||
StringRef name = operand->getVar()->name;
|
||||
|
@ -732,10 +892,11 @@ static void genElementParser(Element *element, OpMethodBody &body,
|
|||
<< (attrDict->isWithKeyword() ? "WithKeyword" : "")
|
||||
<< "(result.attributes))\n"
|
||||
<< " return failure();\n";
|
||||
} else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
|
||||
genCustomDirectiveParser(customDir, body);
|
||||
|
||||
} else if (isa<OperandsDirective>(element)) {
|
||||
body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
|
||||
<< " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
|
||||
"allOperands;\n"
|
||||
<< " if (parser.parseOperandList(allOperands))\n"
|
||||
<< " return failure();\n";
|
||||
} else if (isa<SuccessorsDirective>(element)) {
|
||||
|
@ -980,6 +1141,20 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
|
|||
llvm::interleaveComma(op.getOperands(), body, interleaveFn);
|
||||
body << "}));\n";
|
||||
}
|
||||
|
||||
if (!allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments")) {
|
||||
body << " result.addAttribute(\"result_segment_sizes\", "
|
||||
<< "parser.getBuilder().getI32VectorAttr({";
|
||||
auto interleaveFn = [&](const NamedTypeConstraint &result) {
|
||||
// If the result is variadic emit the parsed size.
|
||||
if (result.isVariableLength())
|
||||
body << "static_cast<int32_t>(" << result.name << "Types.size())";
|
||||
else
|
||||
body << "1";
|
||||
};
|
||||
llvm::interleaveComma(op.getResults(), body, interleaveFn);
|
||||
body << "}));\n";
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1007,6 +1182,8 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
|
|||
// Elide the variadic segment size attributes if necessary.
|
||||
if (!fmt.allOperands && op.getTrait("OpTrait::AttrSizedOperandSegments"))
|
||||
body << "\"operand_segment_sizes\", ";
|
||||
if (!fmt.allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments"))
|
||||
body << "\"result_segment_sizes\", ";
|
||||
llvm::interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) {
|
||||
body << "\"" << attr->name << "\"";
|
||||
});
|
||||
|
@ -1038,6 +1215,42 @@ static void genLiteralPrinter(StringRef value, OpMethodBody &body,
|
|||
lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
|
||||
}
|
||||
|
||||
/// Generate the printer for a literal value. `shouldEmitSpace` is true if a
|
||||
/// space should be emitted before this element. `lastWasPunctuation` is true if
|
||||
/// the previous element was a punctuation literal.
|
||||
static void genCustomDirectivePrinter(CustomDirective *customDir,
|
||||
OpMethodBody &body) {
|
||||
body << " print" << customDir->getName() << "(p";
|
||||
for (Element ¶m : customDir->getArguments()) {
|
||||
body << ", ";
|
||||
if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
|
||||
body << attr->getVar()->name << "Attr()";
|
||||
|
||||
} else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
|
||||
body << operand->getVar()->name << "()";
|
||||
|
||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(¶m)) {
|
||||
body << successor->getVar()->name << "()";
|
||||
|
||||
} else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
|
||||
auto *typeOperand = dir->getOperand();
|
||||
auto *operand = dyn_cast<OperandVariable>(typeOperand);
|
||||
auto *var = operand ? operand->getVar()
|
||||
: cast<ResultVariable>(typeOperand)->getVar();
|
||||
if (var->isVariadic())
|
||||
body << var->name << "().getTypes()";
|
||||
else if (var->isOptional())
|
||||
body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
|
||||
else
|
||||
body << var->name << "().getType()";
|
||||
} else {
|
||||
llvm_unreachable("unknown custom directive parameter");
|
||||
}
|
||||
}
|
||||
|
||||
body << ");\n";
|
||||
}
|
||||
|
||||
/// Generate the C++ for an operand to a (*-)type directive.
|
||||
static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
|
||||
if (isa<OperandsDirective>(arg))
|
||||
|
@ -1145,6 +1358,8 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
|
|||
body << " ::llvm::interleaveComma(" << var->name << "(), p);\n";
|
||||
else
|
||||
body << " p << " << var->name << "();\n";
|
||||
} else if (auto *dir = dyn_cast<CustomDirective>(element)) {
|
||||
genCustomDirectivePrinter(dir, body);
|
||||
} else if (isa<OperandsDirective>(element)) {
|
||||
body << " p << getOperation()->getOperands();\n";
|
||||
} else if (isa<SuccessorsDirective>(element)) {
|
||||
|
@ -1202,12 +1417,15 @@ public:
|
|||
caret,
|
||||
comma,
|
||||
equal,
|
||||
less,
|
||||
greater,
|
||||
question,
|
||||
|
||||
// Keywords.
|
||||
keyword_start,
|
||||
kw_attr_dict,
|
||||
kw_attr_dict_w_keyword,
|
||||
kw_custom,
|
||||
kw_functional_type,
|
||||
kw_operands,
|
||||
kw_results,
|
||||
|
@ -1353,6 +1571,10 @@ Token FormatLexer::lexToken() {
|
|||
return formToken(Token::comma, tokStart);
|
||||
case '=':
|
||||
return formToken(Token::equal, tokStart);
|
||||
case '<':
|
||||
return formToken(Token::less, tokStart);
|
||||
case '>':
|
||||
return formToken(Token::greater, tokStart);
|
||||
case '?':
|
||||
return formToken(Token::question, tokStart);
|
||||
case '(':
|
||||
|
@ -1406,6 +1628,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
|
|||
llvm::StringSwitch<Token::Kind>(str)
|
||||
.Case("attr-dict", Token::kw_attr_dict)
|
||||
.Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
|
||||
.Case("custom", Token::kw_custom)
|
||||
.Case("functional-type", Token::kw_functional_type)
|
||||
.Case("operands", Token::kw_operands)
|
||||
.Case("results", Token::kw_results)
|
||||
|
@ -1421,8 +1644,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
|
|||
|
||||
/// Function to find an element within the given range that has the same name as
|
||||
/// 'name'.
|
||||
template <typename RangeT>
|
||||
static auto findArg(RangeT &&range, StringRef name) {
|
||||
template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
|
||||
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
|
||||
return it != range.end() ? &*it : nullptr;
|
||||
}
|
||||
|
@ -1513,6 +1735,10 @@ private:
|
|||
LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
|
||||
llvm::SMLoc loc, bool isTopLevel,
|
||||
bool withKeyword);
|
||||
LogicalResult parseCustomDirective(std::unique_ptr<Element> &element,
|
||||
llvm::SMLoc loc, bool isTopLevel);
|
||||
LogicalResult parseCustomDirectiveParameter(
|
||||
std::vector<std::unique_ptr<Element>> ¶meters);
|
||||
LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
|
||||
Token tok, bool isTopLevel);
|
||||
LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
|
||||
|
@ -1930,6 +2156,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
|
|||
case Token::kw_attr_dict_w_keyword:
|
||||
return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
|
||||
/*withKeyword=*/true);
|
||||
case Token::kw_custom:
|
||||
return parseCustomDirective(element, dirTok.getLoc(), isTopLevel);
|
||||
case Token::kw_functional_type:
|
||||
return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
|
||||
case Token::kw_operands:
|
||||
|
@ -2054,15 +2282,15 @@ LogicalResult FormatParser::parseOptionalChildElement(
|
|||
seenVariables.insert(ele->getVar());
|
||||
return success();
|
||||
})
|
||||
// Literals and type directives may be used, but they can't anchor the
|
||||
// group.
|
||||
.Case<LiteralElement, TypeDirective, FunctionalTypeDirective>(
|
||||
[&](Element *) {
|
||||
if (isAnchor)
|
||||
return emitError(childLoc, "only variables can be used to anchor "
|
||||
"an optional group");
|
||||
return success();
|
||||
})
|
||||
// Literals, custom directives, and type directives may be used,
|
||||
// but they can't anchor the group.
|
||||
.Case<LiteralElement, CustomDirective, TypeDirective,
|
||||
FunctionalTypeDirective>([&](Element *) {
|
||||
if (isAnchor)
|
||||
return emitError(childLoc, "only variables can be used to anchor "
|
||||
"an optional group");
|
||||
return success();
|
||||
})
|
||||
.Default([&](Element *) {
|
||||
return emitError(childLoc, "only literals, types, and variables can be "
|
||||
"used within an optional group");
|
||||
|
@ -2084,6 +2312,71 @@ FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
|
||||
llvm::SMLoc loc, bool isTopLevel) {
|
||||
llvm::SMLoc curLoc = curToken.getLoc();
|
||||
|
||||
// Parse the custom directive name.
|
||||
if (failed(
|
||||
parseToken(Token::less, "expected '<' before custom directive name")))
|
||||
return failure();
|
||||
|
||||
Token nameTok = curToken;
|
||||
if (failed(parseToken(Token::identifier,
|
||||
"expected custom directive name identifier")) ||
|
||||
failed(parseToken(Token::greater,
|
||||
"expected '>' after custom directive name")) ||
|
||||
failed(parseToken(Token::l_paren,
|
||||
"expected '(' before custom directive parameters")))
|
||||
return failure();
|
||||
|
||||
// Parse the child elements for this optional group.=
|
||||
std::vector<std::unique_ptr<Element>> elements;
|
||||
do {
|
||||
if (failed(parseCustomDirectiveParameter(elements)))
|
||||
return failure();
|
||||
if (curToken.getKind() != Token::comma)
|
||||
break;
|
||||
consumeToken();
|
||||
} while (true);
|
||||
|
||||
if (failed(parseToken(Token::r_paren,
|
||||
"expected ')' after custom directive parameters")))
|
||||
return failure();
|
||||
|
||||
// After parsing all of the elements, ensure that all type directives refer
|
||||
// only to variables.
|
||||
for (auto &ele : elements) {
|
||||
if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
|
||||
if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
|
||||
return emitError(curLoc, "type directives within a custom directive "
|
||||
"may only refer to variables");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
element = std::make_unique<CustomDirective>(nameTok.getSpelling(),
|
||||
std::move(elements));
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult FormatParser::parseCustomDirectiveParameter(
|
||||
std::vector<std::unique_ptr<Element>> ¶meters) {
|
||||
llvm::SMLoc childLoc = curToken.getLoc();
|
||||
parameters.push_back({});
|
||||
if (failed(parseElement(parameters.back(), /*isTopLevel=*/true)))
|
||||
return failure();
|
||||
|
||||
// Verify that the element can be placed within a custom directive.
|
||||
if (!isa<TypeDirective, AttributeVariable, OperandVariable,
|
||||
SuccessorVariable>(parameters.back().get())) {
|
||||
return emitError(childLoc, "only variables and types may be used as "
|
||||
"parameters to a custom directive");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
|
||||
Token tok, bool isTopLevel) {
|
||||
|
|
Loading…
Reference in New Issue