[mlir][DeclarativeParser] Add basic support for optional groups in the assembly format.

When operations have optional attributes, or optional operands(i.e. empty variadic operands), the assembly format often has an optional section to represent these arguments. This revision adds basic support for defining an "optional group" in the assembly format to support this. An optional group is defined by wrapping a set of elements in `()` followed by `?` and requires the following:

* The first element of the group must be either a literal or an operand argument.
  - This is because the first element must be optionally parsable.
* There must be exactly one argument variable within the group that is marked as the anchor of the group. The anchor is the element whose presence controls whether the group should be printed/parsed. An element is marked as the anchor by adding a trailing `^`.
* The group must only contain literals, variables, and type directives.
  - Any attribute variables may be used, but only optional attributes can be marked as the anchor.
  - Only variadic, i.e. optional, operand arguments can be used.
  - The elements of a type directive must be defined within the same optional group.

An example of this can be seen with the assembly format for ReturnOp, which has a variadic number of operands.

```
def ReturnOp : ... {
  let arguments = (ins Variadic<AnyType>:$operands);

  // We only print the operands+types if there are a non-zero number
  // of operands.
  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
```

Differential Revision: https://reviews.llvm.org/D74681
This commit is contained in:
River Riddle 2020-02-21 13:19:15 -08:00
parent 26222db01b
commit 2d0477a003
5 changed files with 473 additions and 169 deletions

View File

@ -619,6 +619,43 @@ the variables would be `$callee` and `$args`.
Attribute variables are printed with their respective value type, unless that Attribute variables are printed with their respective value type, unless that
value type is buildable. In those cases, the type of the attribute is elided. value type is buildable. In those cases, the type of the attribute is elided.
#### Optional Groups
In certain situations operations may have "optional" information, e.g.
attributes or an empty set of variadic operands. In these situtations a section
of the assembly format can be marked as `optional` based on the presence of this
information. An optional group is defined by wrapping a set of elements within
`()` followed by a `?` and has the following requirements:
* The first element of the group must either be a literal or an operand.
- This is because the first element must be optionally parsable.
* Exactly one argument variable within the group must be marked as the anchor
of the group.
- The anchor is the element whose presence controls whether the group
should be printed/parsed.
- An element is marked as the anchor by adding a trailing `^`.
- The first element is *not* required to be the anchor of the group.
* Literals, variables, and type directives are the only valid elements within
the group.
- Any attribute variable may be used, but only optional attributes can be
marked as the anchor.
- Only variadic, i.e. optional, operand arguments can be used.
- The operands to a type directive must be defined within the optional
group.
An example of an operation with an optional group is `std.return`, which has a
variadic number of operands.
```
def ReturnOp : ... {
let arguments = (ins Variadic<AnyType>:$operands);
// We only print the operands and types if there are a non-zero number
// of operands.
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
```
#### Requirements #### Requirements
The format specification has a certain set of requirements that must be adhered The format specification has a certain set of requirements that must be adhered

View File

@ -1059,6 +1059,8 @@ def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> {
let builders = [OpBuilder< let builders = [OpBuilder<
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
>]; >];
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
} }
def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape, def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,

View File

@ -1736,21 +1736,6 @@ OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
// ReturnOp // ReturnOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type, 2> types;
llvm::SMLoc loc = parser.getCurrentLocation();
return failure(parser.parseOperandList(opInfo) ||
(!opInfo.empty() && parser.parseColonTypeList(types)) ||
parser.resolveOperands(opInfo, types, loc, result.operands));
}
static void print(OpAsmPrinter &p, ReturnOp op) {
p << "return";
if (op.getNumOperands() != 0)
p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
}
static LogicalResult verify(ReturnOp op) { static LogicalResult verify(ReturnOp op) {
auto function = cast<FuncOp>(op.getParentOp()); auto function = cast<FuncOp>(op.getParentOp());

View File

@ -46,7 +46,7 @@ def DirectiveFunctionalTypeInvalidA : TestFormat_Op<"functype_invalid_a", [{
def DirectiveFunctionalTypeInvalidB : TestFormat_Op<"functype_invalid_b", [{ def DirectiveFunctionalTypeInvalidB : TestFormat_Op<"functype_invalid_b", [{
functional-type functional-type
}]>; }]>;
// CHECK: error: expected directive, literal, or variable // CHECK: error: expected directive, literal, variable, or optional group
def DirectiveFunctionalTypeInvalidC : TestFormat_Op<"functype_invalid_c", [{ def DirectiveFunctionalTypeInvalidC : TestFormat_Op<"functype_invalid_c", [{
functional-type( functional-type(
}]>; }]>;
@ -54,7 +54,7 @@ def DirectiveFunctionalTypeInvalidC : TestFormat_Op<"functype_invalid_c", [{
def DirectiveFunctionalTypeInvalidD : TestFormat_Op<"functype_invalid_d", [{ def DirectiveFunctionalTypeInvalidD : TestFormat_Op<"functype_invalid_d", [{
functional-type(operands functional-type(operands
}]>; }]>;
// CHECK: error: expected directive, literal, or variable // CHECK: error: expected directive, literal, variable, or optional group
def DirectiveFunctionalTypeInvalidE : TestFormat_Op<"functype_invalid_e", [{ def DirectiveFunctionalTypeInvalidE : TestFormat_Op<"functype_invalid_e", [{
functional-type(operands, functional-type(operands,
}]>; }]>;
@ -98,7 +98,7 @@ def DirectiveResultsInvalidA : TestFormat_Op<"operands_invalid_a", [{
def DirectiveTypeInvalidA : TestFormat_Op<"type_invalid_a", [{ def DirectiveTypeInvalidA : TestFormat_Op<"type_invalid_a", [{
type type
}]>; }]>;
// CHECK: error: expected directive, literal, or variable // CHECK: error: expected directive, literal, variable, or optional group
def DirectiveTypeInvalidB : TestFormat_Op<"type_invalid_b", [{ def DirectiveTypeInvalidB : TestFormat_Op<"type_invalid_b", [{
type( type(
}]>; }]>;
@ -165,7 +165,7 @@ def LiteralInvalidA : TestFormat_Op<"literal_invalid_a", [{
`1` `1`
}]>; }]>;
// CHECK: error: unexpected end of file in literal // CHECK: error: unexpected end of file in literal
// CHECK: error: expected directive, literal, or variable // CHECK: error: expected directive, literal, variable, or optional group
def LiteralInvalidB : TestFormat_Op<"literal_invalid_b", [{ def LiteralInvalidB : TestFormat_Op<"literal_invalid_b", [{
` `
}]>; }]>;
@ -175,6 +175,55 @@ def LiteralValid : TestFormat_Op<"literal_valid", [{
attr-dict attr-dict
}]>; }]>;
//===----------------------------------------------------------------------===//
// Optional Groups
//===----------------------------------------------------------------------===//
// CHECK: error: optional groups can only be used as top-level elements
def OptionalInvalidA : TestFormat_Op<"optional_invalid_a", [{
type(($attr^)?) attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
// CHECK: error: expected directive, literal, variable, or optional group
def OptionalInvalidB : TestFormat_Op<"optional_invalid_b", [{
() attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
// CHECK: error: optional group specified no anchor element
def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{
($attr)? attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
// CHECK: error: first element of an operand group must be a literal or operand
def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{
($attr^)? attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
// CHECK: error: type directive can only refer to variables within the optional group
def OptionalInvalidE : TestFormat_Op<"optional_invalid_e", [{
(`,` $attr^ type(operands))? attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
// CHECK: error: only one element can be marked as the anchor of an optional group
def OptionalInvalidF : TestFormat_Op<"optional_invalid_f", [{
($attr^ $attr2^) attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr, OptionalAttr<I64Attr>:$attr2)>;
// CHECK: error: only optional attributes can be used to anchor an optional group
def OptionalInvalidG : TestFormat_Op<"optional_invalid_g", [{
($attr^) attr-dict
}]>, Arguments<(ins I64Attr:$attr)>;
// CHECK: error: only variadic operands can be used within an optional group
def OptionalInvalidH : TestFormat_Op<"optional_invalid_h", [{
($arg^) attr-dict
}]>, Arguments<(ins I64:$arg)>;
// CHECK: error: only variables can be used to anchor an optional group
def OptionalInvalidI : TestFormat_Op<"optional_invalid_i", [{
($arg type($arg)^) attr-dict
}]>, Arguments<(ins Variadic<I64>:$arg)>;
// CHECK: error: only literals, types, and variables can be used within an optional group
def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{
(attr-dict)
}]>;
// CHECK: error: expected '?' after optional group
def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{
($arg^)
}]>, Arguments<(ins Variadic<I64>:$arg)>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Variables // Variables
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -58,6 +58,9 @@ public:
AttributeVariable, AttributeVariable,
OperandVariable, OperandVariable,
ResultVariable, ResultVariable,
/// This element is an optional element.
Optional,
}; };
Element(Kind kind) : kind(kind) {} Element(Kind kind) : kind(kind) {}
virtual ~Element() = default; virtual ~Element() = default;
@ -164,7 +167,7 @@ namespace {
class LiteralElement : public Element { class LiteralElement : public Element {
public: public:
LiteralElement(StringRef literal) LiteralElement(StringRef literal)
: Element{Kind::Literal}, literal(literal){}; : Element{Kind::Literal}, literal(literal) {}
static bool classof(const Element *element) { static bool classof(const Element *element) {
return element->getKind() == Kind::Literal; return element->getKind() == Kind::Literal;
} }
@ -203,6 +206,36 @@ bool LiteralElement::isValidLiteral(StringRef value) {
}); });
} }
//===----------------------------------------------------------------------===//
// OptionalElement
namespace {
/// This class represents a group of elements that are optionally emitted based
/// upon an optional variable of the operation.
class OptionalElement : public Element {
public:
OptionalElement(std::vector<std::unique_ptr<Element>> &&elements,
unsigned anchor)
: Element{Kind::Optional}, elements(std::move(elements)), anchor(anchor) {
}
static bool classof(const Element *element) {
return element->getKind() == Kind::Optional;
}
/// Return the nested elements of this grouping.
auto getElements() const { return llvm::make_pointee_range(elements); }
/// Return the anchor of this optional group.
Element *getAnchor() const { return elements[anchor].get(); }
private:
/// The child elements of this optional.
std::vector<std::unique_ptr<Element>> elements;
/// The index of the element that acts as the anchor for the optional group.
unsigned anchor;
};
} // end anonymous namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// OperationFormat // OperationFormat
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -327,32 +360,26 @@ const char *const enumAttrParserCode = R"(
const char *const variadicOperandParserCode = R"( const char *const variadicOperandParserCode = R"(
llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation(); llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation();
(void){0}OperandsLoc; (void){0}OperandsLoc;
SmallVector<OpAsmParser::OperandType, 4> {0}Operands;
if (parser.parseOperandList({0}Operands)) if (parser.parseOperandList({0}Operands))
return failure(); return failure();
)"; )";
const char *const operandParserCode = R"( const char *const operandParserCode = R"(
llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation(); llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation();
(void){0}OperandsLoc; (void){0}OperandsLoc;
OpAsmParser::OperandType {0}RawOperands[1];
if (parser.parseOperand({0}RawOperands[0])) if (parser.parseOperand({0}RawOperands[0]))
return failure(); return failure();
ArrayRef<OpAsmParser::OperandType> {0}Operands({0}RawOperands);
)"; )";
/// The code snippet used to generate a parser call for a type list. /// The code snippet used to generate a parser call for a type list.
/// ///
/// {0}: The name for the type list. /// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"( const char *const variadicTypeParserCode = R"(
SmallVector<Type, 1> {0}Types;
if (parser.parseTypeList({0}Types)) if (parser.parseTypeList({0}Types))
return failure(); return failure();
)"; )";
const char *const typeParserCode = R"( const char *const typeParserCode = R"(
Type {0}RawTypes[1] = {{nullptr};
if (parser.parseType({0}RawTypes[0])) if (parser.parseType({0}RawTypes[0]))
return failure(); return failure();
ArrayRef<Type> {0}Types({0}RawTypes);
)"; )";
/// The code snippet used to generate a parser call for a functional type. /// The code snippet used to generate a parser call for a functional type.
@ -363,8 +390,8 @@ const char *const functionalTypeParserCode = R"(
FunctionType {0}__{1}_functionType; FunctionType {0}__{1}_functionType;
if (parser.parseType({0}__{1}_functionType)) if (parser.parseType({0}__{1}_functionType))
return failure(); return failure();
ArrayRef<Type> {0}Types = {0}__{1}_functionType.getInputs(); {0}Types = {0}__{1}_functionType.getInputs();
ArrayRef<Type> {1}Types = {0}__{1}_functionType.getResults(); {1}Types = {0}__{1}_functionType.getResults();
)"; )";
/// Get the name used for the type list for the given type directive operand. /// Get the name used for the type list for the given type directive operand.
@ -388,12 +415,11 @@ static StringRef getTypeListName(Element *arg, bool &isVariadic) {
/// Generate the parser for a literal value. /// Generate the parser for a literal value.
static void genLiteralParser(StringRef value, OpMethodBody &body) { static void genLiteralParser(StringRef value, OpMethodBody &body) {
body << " if (parser.parse";
// Handle the case of a keyword/identifier. // Handle the case of a keyword/identifier.
if (value.front() == '_' || isalpha(value.front())) { if (value.front() == '_' || isalpha(value.front())) {
body << "Keyword(\"" << value << "\")"; body << "Keyword(\"" << value << "\")";
} else { return;
}
body << (StringRef)llvm::StringSwitch<StringRef>(value) body << (StringRef)llvm::StringSwitch<StringRef>(value)
.Case("->", "Arrow()") .Case("->", "Arrow()")
.Case(":", "Colon()") .Case(":", "Colon()")
@ -406,27 +432,69 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
.Case("[", "LSquare()") .Case("[", "LSquare()")
.Case("]", "RSquare()"); .Case("]", "RSquare()");
} }
body << ")\n return failure();\n";
/// Generate the storage code required for parsing the given element.
static void genElementParserStorage(Element *element, OpMethodBody &body) {
if (auto *optional = dyn_cast<OptionalElement>(element)) {
for (auto &childElement : optional->getElements())
genElementParserStorage(&childElement, body);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
StringRef name = operand->getVar()->name;
if (operand->getVar()->isVariadic())
body << " SmallVector<OpAsmParser::OperandType, 4> " << name
<< "Operands;\n";
else
body << " OpAsmParser::OperandType " << name << "RawOperands[1];\n"
<< " ArrayRef<OpAsmParser::OperandType> " << name << "Operands("
<< name << "RawOperands);";
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
bool variadic = false;
StringRef name = getTypeListName(dir->getOperand(), variadic);
if (variadic)
body << " SmallVector<Type, 1> " << name << "Types;\n";
else
body << llvm::formatv(" Type {0}RawTypes[1];\n", name)
<< llvm::formatv(" ArrayRef<Type> {0}Types({0}RawTypes);\n", name);
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
bool ignored = false;
body << " ArrayRef<Type> " << getTypeListName(dir->getInputs(), ignored)
<< "Types;\n";
body << " ArrayRef<Type> " << getTypeListName(dir->getResults(), ignored)
<< "Types;\n";
}
} }
void OperationFormat::genParser(Operator &op, OpClass &opClass) { /// Generate the parser for a single format element.
auto &method = opClass.newMethod( static void genElementParser(Element *element, OpMethodBody &body,
"ParseResult", "parse", "OpAsmParser &parser, OperationState &result", FmtContext &attrTypeCtx) {
OpMethod::MP_Static); /// Optional Group.
auto &body = method.body(); if (auto *optional = dyn_cast<OptionalElement>(element)) {
auto elements = optional->getElements();
// A format context used when parsing attributes with buildable types. // Generate a special optional parser for the first element to gate the
FmtContext attrTypeCtx; // parsing of the rest of the elements.
attrTypeCtx.withBuilder("parser.getBuilder()"); if (auto *literal = dyn_cast<LiteralElement>(&*elements.begin())) {
body << " if (succeeded(parser.parseOptional";
// Generate parsers for each of the elements.
for (auto &element : elements) {
/// Literals.
if (LiteralElement *literal = dyn_cast<LiteralElement>(element.get())) {
genLiteralParser(literal->getLiteral(), body); genLiteralParser(literal->getLiteral(), body);
body << ")) {\n";
} else if (auto *opVar = dyn_cast<OperandVariable>(&*elements.begin())) {
genElementParser(opVar, body, attrTypeCtx);
body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
}
// Generate the rest of the elements normally.
for (auto &childElement : llvm::drop_begin(elements, 1))
genElementParser(&childElement, body, attrTypeCtx);
body << " }\n";
/// Literals.
} else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
body << " if (parser.parse";
genLiteralParser(literal->getLiteral(), body);
body << ")\n return failure();\n";
/// Arguments. /// Arguments.
} else if (auto *attr = dyn_cast<AttributeVariable>(element.get())) { } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
const NamedAttribute *var = attr->getVar(); const NamedAttribute *var = attr->getVar();
// Check to see if we can parse this as an enum attribute. // Check to see if we can parse this as an enum attribute.
@ -441,10 +509,9 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
"attrOptional.getValue()"); "attrOptional.getValue()");
} }
body << formatv(enumAttrParserCode, var->name, body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
enumAttr.getCppNamespace(),
enumAttr.getStringToSymbolFnName(), attrBuilderStr); enumAttr.getStringToSymbolFnName(), attrBuilderStr);
continue; return;
} }
// If this attribute has a buildable type, use that when parsing the // If this attribute has a buildable type, use that when parsing the
@ -459,27 +526,26 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
body << formatv(attrParserCode, var->attr.getStorageType(), var->name, body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
attrTypeStr); attrTypeStr);
} else if (auto *operand = dyn_cast<OperandVariable>(element.get())) { } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
bool isVariadic = operand->getVar()->isVariadic(); bool isVariadic = operand->getVar()->isVariadic();
body << formatv(isVariadic ? variadicOperandParserCode body << formatv(isVariadic ? variadicOperandParserCode : operandParserCode,
: operandParserCode,
operand->getVar()->name); operand->getVar()->name);
/// Directives. /// Directives.
} else if (isa<AttrDictDirective>(element.get())) { } else if (isa<AttrDictDirective>(element)) {
body << " if (parser.parseOptionalAttrDict(result.attributes))\n" body << " if (parser.parseOptionalAttrDict(result.attributes))\n"
<< " return failure();\n"; << " return failure();\n";
} else if (isa<OperandsDirective>(element.get())) { } else if (isa<OperandsDirective>(element)) {
body << " llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n" body << " llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
<< " SmallVector<OpAsmParser::OperandType, 4> allOperands;\n" << " SmallVector<OpAsmParser::OperandType, 4> allOperands;\n"
<< " if (parser.parseOperandList(allOperands))\n" << " if (parser.parseOperandList(allOperands))\n"
<< " return failure();\n"; << " return failure();\n";
} else if (auto *dir = dyn_cast<TypeDirective>(element.get())) { } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
bool isVariadic = false; bool isVariadic = false;
StringRef listName = getTypeListName(dir->getOperand(), isVariadic); StringRef listName = getTypeListName(dir->getOperand(), isVariadic);
body << formatv(isVariadic ? variadicTypeParserCode : typeParserCode, body << formatv(isVariadic ? variadicTypeParserCode : typeParserCode,
listName); listName);
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element.get())) { } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
bool ignored = false; bool ignored = false;
body << formatv(functionalTypeParserCode, body << formatv(functionalTypeParserCode,
getTypeListName(dir->getInputs(), ignored), getTypeListName(dir->getInputs(), ignored),
@ -489,6 +555,26 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
} }
} }
void OperationFormat::genParser(Operator &op, OpClass &opClass) {
auto &method = opClass.newMethod(
"ParseResult", "parse", "OpAsmParser &parser, OperationState &result",
OpMethod::MP_Static);
auto &body = method.body();
// Generate variables to store the operands and type within the format. This
// allows for referencing these variables in the presence of optional
// groupings.
for (auto &element : elements)
genElementParserStorage(&*element, body);
// A format context used when parsing attributes with buildable types.
FmtContext attrTypeCtx;
attrTypeCtx.withBuilder("parser.getBuilder()");
// Generate parsers for each of the elements.
for (auto &element : elements)
genElementParser(element.get(), body, attrTypeCtx);
// Generate the code to resolve the operand and result types now that they // Generate the code to resolve the operand and result types now that they
// have been parsed. // have been parsed.
genParserTypeResolution(op, body); genParserTypeResolution(op, body);
@ -676,7 +762,7 @@ static void genLiteralPrinter(StringRef value, OpMethodBody &body,
lastWasPunctuation = !(value.front() == '_' || isalpha(value.front())); lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
} }
/// Generate the c++ for an operand to a (*-)type directive. /// Generate the C++ for an operand to a (*-)type directive.
static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) { static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
if (isa<OperandsDirective>(arg)) if (isa<OperandsDirective>(arg))
return body << "getOperation()->getOperandTypes()"; return body << "getOperation()->getOperandTypes()";
@ -689,6 +775,79 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
return body << "ArrayRef<Type>(" << var->name << "().getType())"; return body << "ArrayRef<Type>(" << var->name << "().getType())";
} }
/// Generate the code for printing the given element.
static void genElementPrinter(Element *element, OpMethodBody &body,
OperationFormat &fmt, bool &shouldEmitSpace,
bool &lastWasPunctuation) {
if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
lastWasPunctuation);
// Emit an optional group.
if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
// Emit the check for the presence of the anchor element.
Element *anchor = optional->getAnchor();
if (AttributeVariable *attrVar = dyn_cast<AttributeVariable>(anchor))
body << " if (getAttr(\"" << attrVar->getVar()->name << "\")) {\n";
else
body << " if (!" << cast<OperandVariable>(anchor)->getVar()->name
<< "().empty()) {\n";
// Emit each of the elements.
for (Element &childElement : optional->getElements())
genElementPrinter(&childElement, body, fmt, shouldEmitSpace,
lastWasPunctuation);
body << " }\n";
return;
}
// Emit the attribute dictionary.
if (isa<AttrDictDirective>(element)) {
genAttrDictPrinter(fmt, body);
lastWasPunctuation = false;
return;
}
// Optionally insert a space before the next element. The AttrDict printer
// already adds a space as necessary.
if (shouldEmitSpace || !lastWasPunctuation)
body << " p << \" \";\n";
lastWasPunctuation = false;
shouldEmitSpace = true;
if (auto *attr = dyn_cast<AttributeVariable>(element)) {
const NamedAttribute *var = attr->getVar();
// If we are formatting as a enum, symbolize the attribute as a string.
if (canFormatEnumAttr(var)) {
const EnumAttr &enumAttr = cast<EnumAttr>(var->attr);
body << " p << \"\\\"\" << " << enumAttr.getSymbolToStringFnName() << "("
<< var->name << "()) << \"\\\"\";\n";
return;
}
// Elide the attribute type if it is buildable.
Optional<Type> attrType = var->attr.getValueType();
if (attrType && attrType->getBuilderCall())
body << " p.printAttributeWithoutType(" << var->name << "Attr());\n";
else
body << " p.printAttribute(" << var->name << "Attr());\n";
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
body << " p << " << operand->getVar()->name << "();\n";
} else if (isa<OperandsDirective>(element)) {
body << " p << getOperation()->getOperands();\n";
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
body << " p << ";
genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
body << " p.printFunctionalType(";
genTypeOperandPrinter(dir->getInputs(), body) << ", ";
genTypeOperandPrinter(dir->getResults(), body) << ");\n";
} else {
llvm_unreachable("unknown format element");
}
}
void OperationFormat::genPrinter(Operator &op, OpClass &opClass) { void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
auto &method = opClass.newMethod("void", "print", "OpAsmPrinter &p"); auto &method = opClass.newMethod("void", "print", "OpAsmPrinter &p");
auto &body = method.body(); auto &body = method.body();
@ -706,60 +865,9 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
// Flags for if we should emit a space, and if the last element was // Flags for if we should emit a space, and if the last element was
// punctuation. // punctuation.
bool shouldEmitSpace = true, lastWasPunctuation = false; bool shouldEmitSpace = true, lastWasPunctuation = false;
for (auto &element : elements) { for (auto &element : elements)
// Emit a literal element. genElementPrinter(element.get(), body, *this, shouldEmitSpace,
if (LiteralElement *literal = dyn_cast<LiteralElement>(element.get())) {
genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
lastWasPunctuation); lastWasPunctuation);
continue;
}
// Emit the attribute dictionary.
if (isa<AttrDictDirective>(element.get())) {
genAttrDictPrinter(*this, body);
lastWasPunctuation = false;
continue;
}
// Optionally insert a space before the next element. The AttrDict printer
// already adds a space as necessary.
if (shouldEmitSpace || !lastWasPunctuation)
body << " p << \" \";\n";
lastWasPunctuation = false;
shouldEmitSpace = true;
if (auto *attr = dyn_cast<AttributeVariable>(element.get())) {
const NamedAttribute *var = attr->getVar();
// If we are formatting as a enum, symbolize the attribute as a string.
if (canFormatEnumAttr(var)) {
const EnumAttr &enumAttr = cast<EnumAttr>(var->attr);
body << " p << \"\\\"\" << " << enumAttr.getSymbolToStringFnName()
<< "(" << var->name << "()) << \"\\\"\";\n";
continue;
}
// Elide the attribute type if it is buildable.
Optional<Type> attrType = var->attr.getValueType();
if (attrType && attrType->getBuilderCall())
body << " p.printAttributeWithoutType(" << var->name << "Attr());\n";
else
body << " p.printAttribute(" << var->name << "Attr());\n";
} else if (auto *operand = dyn_cast<OperandVariable>(element.get())) {
body << " p << " << operand->getVar()->name << "();\n";
} else if (isa<OperandsDirective>(element.get())) {
body << " p << getOperation()->getOperands();\n";
} else if (auto *dir = dyn_cast<TypeDirective>(element.get())) {
body << " p << ";
genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element.get())) {
body << " p.printFunctionalType(";
genTypeOperandPrinter(dir->getInputs(), body) << ", ";
genTypeOperandPrinter(dir->getResults(), body) << ");\n";
} else {
llvm_unreachable("unknown format element");
}
}
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -778,8 +886,10 @@ public:
// Tokens with no info. // Tokens with no info.
l_paren, l_paren,
r_paren, r_paren,
caret,
comma, comma,
equal, equal,
question,
// Keywords. // Keywords.
keyword_start, keyword_start,
@ -908,10 +1018,14 @@ Token FormatLexer::lexToken() {
return formToken(Token::eof, tokStart); return formToken(Token::eof, tokStart);
// Lex punctuation. // Lex punctuation.
case '^':
return formToken(Token::caret, tokStart);
case ',': case ',':
return formToken(Token::comma, tokStart); return formToken(Token::comma, tokStart);
case '=': case '=':
return formToken(Token::equal, tokStart); return formToken(Token::equal, tokStart);
case '?':
return formToken(Token::question, tokStart);
case '(': case '(':
return formToken(Token::l_paren, tokStart); return formToken(Token::l_paren, tokStart);
case ')': case ')':
@ -1026,6 +1140,12 @@ private:
LogicalResult parseDirective(std::unique_ptr<Element> &element, LogicalResult parseDirective(std::unique_ptr<Element> &element,
bool isTopLevel); bool isTopLevel);
LogicalResult parseLiteral(std::unique_ptr<Element> &element); LogicalResult parseLiteral(std::unique_ptr<Element> &element);
LogicalResult parseOptional(std::unique_ptr<Element> &element,
bool isTopLevel);
LogicalResult parseOptionalChildElement(
std::vector<std::unique_ptr<Element>> &childElements,
SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
Optional<unsigned> &anchorIdx);
/// Parse the various different directives. /// Parse the various different directives.
LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element, LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
@ -1077,6 +1197,7 @@ private:
llvm::SmallBitVector seenOperandTypes, seenResultTypes; llvm::SmallBitVector seenOperandTypes, seenResultTypes;
llvm::DenseSet<const NamedTypeConstraint *> seenOperands; llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
llvm::DenseSet<const NamedAttribute *> seenAttrs; llvm::DenseSet<const NamedAttribute *> seenAttrs;
llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
}; };
} // end anonymous namespace } // end anonymous namespace
@ -1236,11 +1357,14 @@ LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
// Literals. // Literals.
if (curToken.getKind() == Token::literal) if (curToken.getKind() == Token::literal)
return parseLiteral(element); return parseLiteral(element);
// Optionals.
if (curToken.getKind() == Token::l_paren)
return parseOptional(element, isTopLevel);
// Variables. // Variables.
if (curToken.getKind() == Token::variable) if (curToken.getKind() == Token::variable)
return parseVariable(element, isTopLevel); return parseVariable(element, isTopLevel);
return emitError(curToken.getLoc(), return emitError(curToken.getLoc(),
"expected directive, literal, or variable"); "expected directive, literal, variable, or optional group");
} }
LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element, LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
@ -1314,6 +1438,115 @@ LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element) {
return success(); return success();
} }
LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
bool isTopLevel) {
llvm::SMLoc curLoc = curToken.getLoc();
if (!isTopLevel)
return emitError(curLoc, "optional groups can only be used as top-level "
"elements");
consumeToken();
// Parse the child elements for this optional group.
std::vector<std::unique_ptr<Element>> elements;
SmallPtrSet<const NamedTypeConstraint *, 8> seenVariables;
Optional<unsigned> anchorIdx;
do {
if (failed(parseOptionalChildElement(elements, seenVariables, anchorIdx)))
return failure();
} while (curToken.getKind() != Token::r_paren);
consumeToken();
if (failed(parseToken(Token::question, "expected '?' after optional group")))
return failure();
// The optional group is required to have an anchor.
if (!anchorIdx)
return emitError(curLoc, "optional group specified no anchor element");
// The first element of the group must be one that can be parsed/printed in an
// optional fashion.
if (!isa<LiteralElement>(&*elements.front()) &&
!isa<OperandVariable>(&*elements.front()))
return emitError(curLoc, "first element of an operand group must be a "
"literal or operand");
// After parsing all of the elements, ensure that all type directives refer
// only to elements within the group.
auto checkTypeOperand = [&](Element *typeEle) {
auto *opVar = dyn_cast<OperandVariable>(typeEle);
const NamedTypeConstraint *var = opVar ? opVar->getVar() : nullptr;
if (!seenVariables.count(var))
return emitError(curLoc, "type directive can only refer to variables "
"within the optional group");
return success();
};
for (auto &ele : elements) {
if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
if (failed(checkTypeOperand(typeEle->getOperand())))
return failure();
} else if (auto *typeEle = dyn_cast<FunctionalTypeDirective>(ele.get())) {
if (failed(checkTypeOperand(typeEle->getInputs())) ||
failed(checkTypeOperand(typeEle->getResults())))
return failure();
}
}
optionalVariables.insert(seenVariables.begin(), seenVariables.end());
element = std::make_unique<OptionalElement>(std::move(elements), *anchorIdx);
return success();
}
LogicalResult FormatParser::parseOptionalChildElement(
std::vector<std::unique_ptr<Element>> &childElements,
SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
Optional<unsigned> &anchorIdx) {
llvm::SMLoc childLoc = curToken.getLoc();
childElements.push_back({});
if (failed(parseElement(childElements.back(), /*isTopLevel=*/true)))
return failure();
// Check to see if this element is the anchor of the optional group.
bool isAnchor = curToken.getKind() == Token::caret;
if (isAnchor) {
if (anchorIdx)
return emitError(childLoc, "only one element can be marked as the anchor "
"of an optional group");
anchorIdx = childElements.size() - 1;
consumeToken();
}
return TypeSwitch<Element *, LogicalResult>(childElements.back().get())
// All attributes can be within the optional group, but only optional
// attributes can be the anchor.
.Case([&](AttributeVariable *attrEle) {
if (isAnchor && !attrEle->getVar()->attr.isOptional())
return emitError(childLoc, "only optional attributes can be used to "
"anchor an optional group");
return success();
})
// Only optional-like(i.e. variadic) operands can be within an optional
// group.
.Case<OperandVariable>([&](auto *ele) {
if (!ele->getVar()->isVariadic())
return emitError(childLoc, "only variadic operands can be used within"
" an optional group");
seenVariables.insert(ele->getVar());
return success();
})
// Literals and type directives may be used, but they can't anchor the
// group.
.Case<LiteralElement, TypeDirective, FunctionalTypeDirective>(
[&](auto *) {
if (isAnchor)
return emitError(childLoc, "only variables can be used to anchor "
"an optional group");
return success();
})
.Default([&](auto *) {
return emitError(childLoc, "only literals, types, and variables can be "
"used within an optional group");
});
}
LogicalResult LogicalResult
FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element, FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel) { llvm::SMLoc loc, bool isTopLevel) {
@ -1344,8 +1577,6 @@ FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
failed(parseTypeDirectiveOperand(results)) || failed(parseTypeDirectiveOperand(results)) ||
failed(parseToken(Token::r_paren, "expected ')' after argument list"))) failed(parseToken(Token::r_paren, "expected ')' after argument list")))
return failure(); return failure();
// Get the proper directive kind and create it.
element = std::make_unique<FunctionalTypeDirective>(std::move(inputs), element = std::make_unique<FunctionalTypeDirective>(std::move(inputs),
std::move(results)); std::move(results));
return success(); return success();