[mlir][DeclarativeParser] Add support for formatting the successors of an operation.

This revision add support for formatting successor variables in a similar way to operands, attributes, etc.

Differential Revision: https://reviews.llvm.org/D74789
This commit is contained in:
River Riddle 2020-02-21 13:20:06 -08:00
parent b1de971ba8
commit 9eb436feaa
15 changed files with 228 additions and 149 deletions

View File

@ -625,6 +625,10 @@ The available directives are as follows:
- Represents all of the results of an operation. - Represents all of the results of an operation.
* `successors`
- Represents all of the successors of an operation.
* `type` ( input ) * `type` ( input )
- Represents the type of the given input. - Represents the type of the given input.
@ -641,8 +645,8 @@ The following are the set of valid punctuation:
#### Variables #### Variables
A variable is an entity that has been registered on the operation itself, i.e. A variable is an entity that has been registered on the operation itself, i.e.
an argument(attribute or operand), result, etc. In the `CallOp` example above, an argument(attribute or operand), result, successor, etc. In the `CallOp`
the variables would be `$callee` and `$args`. example above, the variables would be `$callee` and `$args`.
Attribute variables are printed with their respective value type, unless that Attribute variables are printed with their respective value type, unless that
value type is buildable. In those cases, the type of the attribute is elided. value type is buildable. In those cases, the type of the attribute is elided.

View File

@ -455,15 +455,12 @@ def LLVM_SelectOp
// Terminators. // Terminators.
def LLVM_BrOp : LLVM_TerminatorOp<"br", []> { def LLVM_BrOp : LLVM_TerminatorOp<"br", []> {
let successors = (successor AnySuccessor:$dest); let successors = (successor AnySuccessor:$dest);
let parser = [{ return parseBrOp(parser, result); }]; let assemblyFormat = "$dest attr-dict";
let printer = [{ printBrOp(p, *this); }];
} }
def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> { def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> {
let arguments = (ins LLVMI1:$condition); let arguments = (ins LLVMI1:$condition);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
let assemblyFormat = "$condition `,` successors attr-dict";
let parser = [{ return parseCondBrOp(parser, result); }];
let printer = [{ printCondBrOp(p, *this); }];
} }
def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []>, def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []>,
Arguments<(ins Variadic<LLVM_Type>:$args)> { Arguments<(ins Variadic<LLVM_Type>:$args)> {

View File

@ -69,6 +69,8 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
}]; }];
let autogenSerialization = 0; let autogenSerialization = 0;
let assemblyFormat = "successors attr-dict";
} }
// ----- // -----

View File

@ -250,6 +250,7 @@ def BranchOp : Std_Op<"br", [Terminator]> {
}]; }];
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
let assemblyFormat = "$dest attr-dict";
} }
def CallOp : Std_Op<"call", [CallOpInterface]> { def CallOp : Std_Op<"call", [CallOpInterface]> {
@ -602,6 +603,7 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
}]; }];
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
let assemblyFormat = "$condition `,` successors attr-dict";
} }
def ConstantOp : Std_Op<"constant", def ConstantOp : Std_Op<"constant",

View File

@ -578,6 +578,11 @@ public:
virtual ParseResult virtual ParseResult
parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0; parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
/// Parse an optional operation successor and its operand list.
virtual OptionalParseResult
parseOptionalSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value> &operands) = 0;
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Type Parsing // Type Parsing
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//

View File

@ -780,69 +780,6 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::BrOp.
//===----------------------------------------------------------------------===//
static void printBrOp(OpAsmPrinter &p, BrOp &op) {
p << op.getOperationName() << ' ';
p.printSuccessorAndUseList(op.getOperation(), 0);
p.printOptionalAttrDict(op.getAttrs());
}
// <operation> ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)?
// attribute-dict?
static ParseResult parseBrOp(OpAsmParser &parser, OperationState &result) {
Block *dest;
SmallVector<Value, 4> operands;
if (parser.parseSuccessorAndUseList(dest, operands) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
result.addSuccessor(dest, operands);
return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::CondBrOp.
//===----------------------------------------------------------------------===//
static void printCondBrOp(OpAsmPrinter &p, CondBrOp &op) {
p << op.getOperationName() << ' ' << op.getOperand(0) << ", ";
p.printSuccessorAndUseList(op.getOperation(), 0);
p << ", ";
p.printSuccessorAndUseList(op.getOperation(), 1);
p.printOptionalAttrDict(op.getAttrs());
}
// <operation> ::= `llvm.cond_br` ssa-use `,`
// bb-id (`[` ssa-use-and-type-list `]`)? `,`
// bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict?
static ParseResult parseCondBrOp(OpAsmParser &parser, OperationState &result) {
Block *trueDest;
Block *falseDest;
SmallVector<Value, 4> trueOperands;
SmallVector<Value, 4> falseOperands;
OpAsmParser::OperandType condition;
Builder &builder = parser.getBuilder();
auto *llvmDialect =
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect);
if (parser.parseOperand(condition) || parser.parseComma() ||
parser.parseSuccessorAndUseList(trueDest, trueOperands) ||
parser.parseComma() ||
parser.parseSuccessorAndUseList(falseDest, falseOperands) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.resolveOperand(condition, i1Type, result.operands))
return failure();
result.addSuccessor(trueDest, trueOperands);
result.addSuccessor(falseDest, falseOperands);
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ReturnOp. // Printing/parsing for LLVM::ReturnOp.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1018,24 +1018,6 @@ void spirv::BitcastOp::getCanonicalizationPatterns(
results.insert<ConvertChainedBitcast>(context); results.insert<ConvertChainedBitcast>(context);
} }
//===----------------------------------------------------------------------===//
// spv.BranchOp
//===----------------------------------------------------------------------===//
static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &state) {
Block *dest;
SmallVector<Value, 4> destOperands;
if (parser.parseSuccessorAndUseList(dest, destOperands))
return failure();
state.addSuccessor(dest, destOperands);
return success();
}
static void print(spirv::BranchOp branchOp, OpAsmPrinter &printer) {
printer << spirv::BranchOp::getOperationName() << ' ';
printer.printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// spv.BranchConditionalOp // spv.BranchConditionalOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -414,20 +414,6 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
}; };
} // end anonymous namespace. } // end anonymous namespace.
static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
Block *dest;
SmallVector<Value, 4> destOperands;
if (parser.parseSuccessorAndUseList(dest, destOperands))
return failure();
result.addSuccessor(dest, destOperands);
return success();
}
static void print(OpAsmPrinter &p, BranchOp op) {
p << "br ";
p.printSuccessorAndUseList(op.getOperation(), 0);
}
Block *BranchOp::getDest() { return getSuccessor(0); } Block *BranchOp::getDest() { return getSuccessor(0); }
void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); } void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); }
@ -810,42 +796,6 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
}; };
} // end anonymous namespace. } // end anonymous namespace.
static ParseResult parseCondBranchOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<Value, 4> destOperands;
Block *dest;
OpAsmParser::OperandType condInfo;
// Parse the condition.
Type int1Ty = parser.getBuilder().getI1Type();
if (parser.parseOperand(condInfo) || parser.parseComma() ||
parser.resolveOperand(condInfo, int1Ty, result.operands)) {
return parser.emitError(parser.getNameLoc(),
"expected condition type was boolean (i1)");
}
// Parse the true successor.
if (parser.parseSuccessorAndUseList(dest, destOperands))
return failure();
result.addSuccessor(dest, destOperands);
// Parse the false successor.
destOperands.clear();
if (parser.parseComma() ||
parser.parseSuccessorAndUseList(dest, destOperands))
return failure();
result.addSuccessor(dest, destOperands);
return success();
}
static void print(OpAsmPrinter &p, CondBranchOp op) {
p << "cond_br " << op.getCondition() << ", ";
p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
p << ", ";
p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
}
void CondBranchOp::getCanonicalizationPatterns( void CondBranchOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SimplifyConstCondBranchPred>(context); results.insert<SimplifyConstCondBranchPred>(context);

View File

@ -4423,6 +4423,15 @@ public:
return parser.parseSuccessorAndUseList(dest, operands); return parser.parseSuccessorAndUseList(dest, operands);
} }
/// Parse an optional operation successor and its operand list.
OptionalParseResult
parseOptionalSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value> &operands) override {
if (parser.getToken().isNot(Token::caret_identifier))
return llvm::None;
return parseSuccessorAndUseList(dest, operands);
}
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Type Parsing // Type Parsing
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//

View File

@ -24,8 +24,8 @@ func @branch_argument() -> () {
// ----- // -----
func @missing_accessor() -> () { func @missing_accessor() -> () {
// expected-error @+1 {{has incorrect number of successors: expected 1 but found 0}}
spv.Branch spv.Branch
// expected-error @+1 {{expected block name}}
} }
// ----- // -----

View File

@ -402,7 +402,6 @@ func @condbr_notbool() {
^bb0: ^bb0:
%a = "foo"() : () -> i32 // expected-note {{prior use here}} %a = "foo"() : () -> i32 // expected-note {{prior use here}}
cond_br %a, ^bb0, ^bb0 // expected-error {{use of value '%a' expects different type than prior uses: 'i1' vs 'i32'}} cond_br %a, ^bb0, ^bb0 // expected-error {{use of value '%a' expects different type than prior uses: 'i1' vs 'i32'}}
// expected-error@-1 {{expected condition type was boolean (i1)}}
} }
// ----- // -----

View File

@ -1139,4 +1139,9 @@ def FormatOperandEOp : FormatOperandBase<"format_operand_e_op", [{
$buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict $buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict
}]>; }]>;
def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> {
let successors = (successor VariadicSuccessor<AnySuccessor>:$targets);
let assemblyFormat = "$targets attr-dict";
}
#endif // TEST_OPS #endif // TEST_OPS

View File

@ -94,10 +94,18 @@ def DirectiveOperandsValid : TestFormat_Op<"operands_valid", [{
// results // results
// CHECK: error: 'results' directive can not be used as a top-level directive // CHECK: error: 'results' directive can not be used as a top-level directive
def DirectiveResultsInvalidA : TestFormat_Op<"operands_invalid_a", [{ def DirectiveResultsInvalidA : TestFormat_Op<"results_invalid_a", [{
results results
}]>; }]>;
//===----------------------------------------------------------------------===//
// successors
// CHECK: error: 'successors' is only valid as a top-level directive
def DirectiveSuccessorsInvalidA : TestFormat_Op<"successors_invalid_a", [{
type(successors)
}]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// type // type
@ -235,7 +243,7 @@ def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{
// Variables // Variables
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// CHECK: error: expected variable to refer to a argument or result // CHECK: error: expected variable to refer to a argument, result, or successor
def VariableInvalidA : TestFormat_Op<"variable_invalid_a", [{ def VariableInvalidA : TestFormat_Op<"variable_invalid_a", [{
$unknown_arg attr-dict $unknown_arg attr-dict
}]>; }]>;
@ -255,6 +263,18 @@ def VariableInvalidD : TestFormat_Op<"variable_invalid_d", [{
def VariableInvalidE : TestFormat_Op<"variable_invalid_e", [{ def VariableInvalidE : TestFormat_Op<"variable_invalid_e", [{
$result attr-dict $result attr-dict
}]>, Results<(outs I64:$result)>; }]>, Results<(outs I64:$result)>;
// CHECK: error: successor 'successor' is already bound
def VariableInvalidF : TestFormat_Op<"variable_invalid_f", [{
$successor $successor attr-dict
}]> {
let successors = (successor AnySuccessor:$successor);
}
// CHECK: error: successor 'successor' is already bound
def VariableInvalidG : TestFormat_Op<"variable_invalid_g", [{
successors $successor attr-dict
}]> {
let successors = (successor AnySuccessor:$successor);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Coverage Checks // Coverage Checks

View File

@ -41,3 +41,19 @@ test.format_operand_d_op %i64, %memref : memref<1xf64>
// CHECK: test.format_operand_e_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64> // CHECK: test.format_operand_e_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64>
test.format_operand_e_op %i64, %memref : i64, memref<1xf64> test.format_operand_e_op %i64, %memref : i64, memref<1xf64>
"foo.successor_test_region"() ( {
^bb0:
// CHECK: test.format_successor_a_op ^bb1 {attr}
test.format_successor_a_op ^bb1 {attr}
^bb1:
// CHECK: test.format_successor_a_op ^bb1, ^bb2 {attr}
test.format_successor_a_op ^bb1, ^bb2 {attr}
^bb2:
// CHECK: test.format_successor_a_op {attr}
test.format_successor_a_op {attr}
}) { arg_names = ["i", "j", "k"] } : () -> ()

View File

@ -49,6 +49,7 @@ public:
FunctionalTypeDirective, FunctionalTypeDirective,
OperandsDirective, OperandsDirective,
ResultsDirective, ResultsDirective,
SuccessorsDirective,
TypeDirective, TypeDirective,
/// This element is a literal. /// This element is a literal.
@ -58,6 +59,7 @@ public:
AttributeVariable, AttributeVariable,
OperandVariable, OperandVariable,
ResultVariable, ResultVariable,
SuccessorVariable,
/// This element is an optional element. /// This element is an optional element.
Optional, Optional,
@ -105,6 +107,10 @@ using OperandVariable =
/// This class represents a variable that refers to a result. /// This class represents a variable that refers to a result.
using ResultVariable = using ResultVariable =
VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>; VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;
/// This class represents a variable that refers to a successor.
using SuccessorVariable =
VariableElement<NamedSuccessor, Element::Kind::SuccessorVariable>;
} // end anonymous namespace } // end anonymous namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -126,6 +132,11 @@ using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
/// all of the results of an operation. /// all of the results of an operation.
using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>; using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
/// This class represents the `successors` directive. This directive represents
/// all of the successors of an operation.
using SuccessorsDirective =
DirectiveElement<Element::Kind::SuccessorsDirective>;
/// This class represents the `attr-dict` directive. This directive represents /// This class represents the `attr-dict` directive. This directive represents
/// the attribute dictionary of the operation. /// the attribute dictionary of the operation.
class AttrDictDirective class AttrDictDirective
@ -294,6 +305,8 @@ struct OperationFormat {
/// Generate the c++ to resolve the types of operands and results during /// Generate the c++ to resolve the types of operands and results during
/// parsing. /// parsing.
void genParserTypeResolution(Operator &op, OpMethodBody &body); void genParserTypeResolution(Operator &op, OpMethodBody &body);
/// Generate the c++ to resolve successors during parsing.
void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
/// Generate the operation printer from this format. /// Generate the operation printer from this format.
void genPrinter(Operator &op, OpClass &opClass); void genPrinter(Operator &op, OpClass &opClass);
@ -403,6 +416,51 @@ const char *const functionalTypeParserCode = R"(
{1}Types = {0}__{1}_functionType.getResults(); {1}Types = {0}__{1}_functionType.getResults();
)"; )";
/// The code snippet used to generate a parser call for a successor list.
///
/// {0}: The name for the successor list.
const char *successorListParserCode = R"(
SmallVector<std::pair<Block *, SmallVector<Value, 4>>, 2> {0}Successors;
{
Block *succ;
SmallVector<Value, 4> succOperands;
// Parse the first successor.
auto firstSucc = parser.parseOptionalSuccessorAndUseList(succ,
succOperands);
if (firstSucc.hasValue()) {
if (failed(*firstSucc))
return failure();
{0}Successors.emplace_back(succ, succOperands);
// Parse any trailing successors.
while (succeeded(parser.parseOptionalComma())) {
succOperands.clear();
if (parser.parseSuccessorAndUseList(succ, succOperands))
return failure();
{0}Successors.emplace_back(succ, succOperands);
}
}
}
)";
/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *successorParserCode = R"(
Block *{0}Successor = nullptr;
SmallVector<Value, 4> {0}Operands;
if (parser.parseSuccessorAndUseList({0}Successor, {0}Operands))
return failure();
)";
/// The code snippet used to resolve a list of parsed successors.
///
/// {0}: The name of the successor list.
const char *resolveSuccessorListParserCode = R"(
for (auto &succAndArgs : {0}Successors)
result.addSuccessor(succAndArgs.first, succAndArgs.second);
)";
/// Get the name used for the type list for the given type directive operand. /// Get the name used for the type list for the given type directive operand.
/// 'isVariadic' is set to true if the operand has variadic types. /// 'isVariadic' is set to true if the operand has variadic types.
static StringRef getTypeListName(Element *arg, bool &isVariadic) { static StringRef getTypeListName(Element *arg, bool &isVariadic) {
@ -539,6 +597,10 @@ static void genElementParser(Element *element, OpMethodBody &body,
bool isVariadic = operand->getVar()->isVariadic(); bool isVariadic = operand->getVar()->isVariadic();
body << formatv(isVariadic ? variadicOperandParserCode : operandParserCode, body << formatv(isVariadic ? variadicOperandParserCode : operandParserCode,
operand->getVar()->name); operand->getVar()->name);
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
bool isVariadic = successor->getVar()->isVariadic();
body << formatv(isVariadic ? successorListParserCode : successorParserCode,
successor->getVar()->name);
/// Directives. /// Directives.
} else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) { } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
@ -551,6 +613,8 @@ static void genElementParser(Element *element, OpMethodBody &body,
<< " SmallVector<OpAsmParser::OperandType, 4> allOperands;\n" << " SmallVector<OpAsmParser::OperandType, 4> allOperands;\n"
<< " if (parser.parseOperandList(allOperands))\n" << " if (parser.parseOperandList(allOperands))\n"
<< " return failure();\n"; << " return failure();\n";
} else if (isa<SuccessorsDirective>(element)) {
body << llvm::formatv(successorListParserCode, "full");
} else if (auto *dir = dyn_cast<TypeDirective>(element)) { } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
bool isVariadic = false; bool isVariadic = false;
StringRef listName = getTypeListName(dir->getOperand(), isVariadic); StringRef listName = getTypeListName(dir->getOperand(), isVariadic);
@ -586,9 +650,10 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
for (auto &element : elements) for (auto &element : elements)
genElementParser(element.get(), body, attrTypeCtx); genElementParser(element.get(), body, attrTypeCtx);
// Generate the code to resolve the operand and result types now that they // Generate the code to resolve the operand/result types and successors now
// have been parsed. // that they have been parsed.
genParserTypeResolution(op, body); genParserTypeResolution(op, body);
genParserSuccessorResolution(op, body);
body << " return success();\n"; body << " return success();\n";
} }
@ -730,6 +795,28 @@ void OperationFormat::genParserTypeResolution(Operator &op,
} }
} }
void OperationFormat::genParserSuccessorResolution(Operator &op,
OpMethodBody &body) {
// Check for the case where all successors were parsed.
bool hasAllSuccessors = llvm::any_of(
elements, [](auto &elt) { return isa<SuccessorsDirective>(elt.get()); });
if (hasAllSuccessors) {
body << llvm::formatv(resolveSuccessorListParserCode, "full");
return;
}
// Otherwise, handle each successor individually.
for (const NamedSuccessor &successor : op.getSuccessors()) {
if (successor.isVariadic()) {
body << llvm::formatv(resolveSuccessorListParserCode, successor.name);
continue;
}
body << llvm::formatv(" result.addSuccessor({0}Successor, {0}Operands);\n",
successor.name);
}
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// PrinterGen // PrinterGen
@ -790,8 +877,8 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
/// Generate the code for printing the given element. /// Generate the code for printing the given element.
static void genElementPrinter(Element *element, OpMethodBody &body, static void genElementPrinter(Element *element, OpMethodBody &body,
OperationFormat &fmt, bool &shouldEmitSpace, OperationFormat &fmt, Operator &op,
bool &lastWasPunctuation) { bool &shouldEmitSpace, bool &lastWasPunctuation) {
if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace, return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
lastWasPunctuation); lastWasPunctuation);
@ -808,7 +895,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
// Emit each of the elements. // Emit each of the elements.
for (Element &childElement : optional->getElements()) for (Element &childElement : optional->getElements())
genElementPrinter(&childElement, body, fmt, shouldEmitSpace, genElementPrinter(&childElement, body, fmt, op, shouldEmitSpace,
lastWasPunctuation); lastWasPunctuation);
body << " }\n"; body << " }\n";
return; return;
@ -847,8 +934,30 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
body << " p.printAttribute(" << var->name << "Attr());\n"; body << " p.printAttribute(" << var->name << "Attr());\n";
} else if (auto *operand = dyn_cast<OperandVariable>(element)) { } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
body << " p << " << operand->getVar()->name << "();\n"; body << " p << " << operand->getVar()->name << "();\n";
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
const NamedSuccessor *var = successor->getVar();
if (var->isVariadic()) {
body << " {\n"
<< " auto succRange = " << var->name << "();\n"
<< " auto opSuccBegin = getOperation()->successor_begin();\n"
<< " int i = succRange.begin() - opSuccBegin;\n"
<< " int e = i + succRange.size();\n"
<< " interleaveComma(llvm::seq<int>(i, e), p, [&](int i) {\n"
<< " p.printSuccessorAndUseList(*this, i);\n"
<< " });\n"
<< " }\n";
return;
}
unsigned index = successor->getVar() - op.successor_begin();
body << " p.printSuccessorAndUseList(*this, " << index << ");\n";
} else if (isa<OperandsDirective>(element)) { } else if (isa<OperandsDirective>(element)) {
body << " p << getOperation()->getOperands();\n"; body << " p << getOperation()->getOperands();\n";
} else if (isa<SuccessorsDirective>(element)) {
body << " interleaveComma(llvm::seq<int>(0, "
"getOperation()->getNumSuccessors()), p, [&](int i) {"
<< " p.printSuccessorAndUseList(*this, i);"
<< " });\n";
} else if (auto *dir = dyn_cast<TypeDirective>(element)) { } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
body << " p << "; body << " p << ";
genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
@ -879,7 +988,7 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
// punctuation. // punctuation.
bool shouldEmitSpace = true, lastWasPunctuation = false; bool shouldEmitSpace = true, lastWasPunctuation = false;
for (auto &element : elements) for (auto &element : elements)
genElementPrinter(element.get(), body, *this, shouldEmitSpace, genElementPrinter(element.get(), body, *this, op, shouldEmitSpace,
lastWasPunctuation); lastWasPunctuation);
} }
@ -911,6 +1020,7 @@ public:
kw_functional_type, kw_functional_type,
kw_operands, kw_operands,
kw_results, kw_results,
kw_successors,
kw_type, kw_type,
keyword_end, keyword_end,
@ -1094,6 +1204,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
.Case("functional-type", Token::kw_functional_type) .Case("functional-type", Token::kw_functional_type)
.Case("operands", Token::kw_operands) .Case("operands", Token::kw_operands)
.Case("results", Token::kw_results) .Case("results", Token::kw_results)
.Case("successors", Token::kw_successors)
.Case("type", Token::kw_type) .Case("type", Token::kw_type)
.Default(Token::identifier); .Default(Token::identifier);
return Token(kind, str); return Token(kind, str);
@ -1173,6 +1284,8 @@ private:
llvm::SMLoc loc, bool isTopLevel); llvm::SMLoc loc, bool isTopLevel);
LogicalResult parseResultsDirective(std::unique_ptr<Element> &element, LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel); llvm::SMLoc loc, bool isTopLevel);
LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel);
LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok, LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
bool isTopLevel); bool isTopLevel);
LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element); LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element);
@ -1211,9 +1324,11 @@ private:
// The following are various bits of format state used for verification // The following are various bits of format state used for verification
// during parsing. // during parsing.
bool hasAllOperands = false, hasAttrDict = false; bool hasAllOperands = false, hasAttrDict = false;
bool hasAllSuccessors = false;
llvm::SmallBitVector seenOperandTypes, seenResultTypes; llvm::SmallBitVector seenOperandTypes, seenResultTypes;
llvm::DenseSet<const NamedTypeConstraint *> seenOperands; llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
llvm::DenseSet<const NamedAttribute *> seenAttrs; llvm::DenseSet<const NamedAttribute *> seenAttrs;
llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
llvm::DenseSet<const NamedTypeConstraint *> optionalVariables; llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
}; };
} // end anonymous namespace } // end anonymous namespace
@ -1313,6 +1428,17 @@ LogicalResult FormatParser::parse() {
auto it = buildableTypes.insert({*builder, buildableTypes.size()}); auto it = buildableTypes.insert({*builder, buildableTypes.size()});
fmt.operandTypes[i].setBuilderIdx(it.first->second); fmt.operandTypes[i].setBuilderIdx(it.first->second);
} }
// Check that all of the successors are within the format.
if (!hasAllSuccessors) {
for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
const NamedSuccessor &successor = op.getSuccessor(i);
if (!seenSuccessors.count(&successor)) {
return emitError(loc, "format missing instance of successor #" +
Twine(i) + "('" + successor.name + "')");
}
}
}
return success(); return success();
} }
@ -1417,7 +1543,17 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
element = std::make_unique<ResultVariable>(result); element = std::make_unique<ResultVariable>(result);
return success(); return success();
} }
return emitError(loc, "expected variable to refer to a argument or result"); /// Successors.
if (const auto *successor = findArg(op.getSuccessors(), name)) {
if (!isTopLevel)
return emitError(loc, "successors can only be used at the top level");
if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
return emitError(loc, "successor '" + name + "' is already bound");
element = std::make_unique<SuccessorVariable>(successor);
return success();
}
return emitError(
loc, "expected variable to refer to a argument, result, or successor");
} }
LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element, LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
@ -1438,6 +1574,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel); return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_results: case Token::kw_results:
return parseResultsDirective(element, dirTok.getLoc(), isTopLevel); return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_successors:
return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_type: case Token::kw_type:
return parseTypeDirective(element, dirTok, isTopLevel); return parseTypeDirective(element, dirTok, isTopLevel);
@ -1624,6 +1762,19 @@ FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
return success(); return success();
} }
LogicalResult
FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel) {
if (!isTopLevel)
return emitError(loc,
"'successors' is only valid as a top-level directive");
if (hasAllSuccessors || !seenSuccessors.empty())
return emitError(loc, "'successors' directive creates overlap in format");
hasAllSuccessors = true;
element = std::make_unique<SuccessorsDirective>();
return success();
}
LogicalResult LogicalResult
FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok, FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
bool isTopLevel) { bool isTopLevel) {