forked from OSchip/llvm-project
Parse operations in ML functions. Add builder class for ML functions.
Refactors operation parsing to share functionality between CFG and ML functions. ML function construction now goes through a builder, similar to the way it is done for CFG functions. PiperOrigin-RevId: 204779279
This commit is contained in:
parent
8e8114a96d
commit
fc7d6dbe5e
|
@ -19,6 +19,8 @@
|
|||
#define MLIR_IR_BUILDERS_H
|
||||
|
||||
#include "mlir/IR/CFGFunction.h"
|
||||
#include "mlir/IR/MLFunction.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
|
@ -146,7 +148,54 @@ private:
|
|||
BasicBlock::iterator insertPoint;
|
||||
};
|
||||
|
||||
// TODO: MLFuncBuilder
|
||||
/// This class helps build an MLFunction. Statements that are created are
|
||||
/// automatically inserted at an insertion point or added to the current
|
||||
/// statement block.
|
||||
class MLFuncBuilder : public Builder {
|
||||
public:
|
||||
MLFuncBuilder(MLFunction *function) : Builder(function->getContext()) {}
|
||||
|
||||
MLFuncBuilder(StmtBlock *block) : MLFuncBuilder(block->getFunction()) {
|
||||
setInsertionPoint(block);
|
||||
}
|
||||
|
||||
/// Reset the insertion point to no location. Creating an operation without a
|
||||
/// set insertion point is an error, but this can still be useful when the
|
||||
/// current insertion point a builder refers to is being removed.
|
||||
void clearInsertionPoint() {
|
||||
this->block = nullptr;
|
||||
insertPoint = StmtBlock::iterator();
|
||||
}
|
||||
|
||||
/// Set the insertion point to the end of the specified block.
|
||||
void setInsertionPoint(StmtBlock *block) {
|
||||
this->block = block;
|
||||
insertPoint = block->end();
|
||||
}
|
||||
|
||||
OperationStmt *createOperation(Identifier name,
|
||||
ArrayRef<NamedAttribute> attributes) {
|
||||
auto op = new OperationStmt(name, attributes, context);
|
||||
block->getStatements().push_back(op);
|
||||
return op;
|
||||
}
|
||||
|
||||
ForStmt *createFor() {
|
||||
auto stmt = new ForStmt();
|
||||
block->getStatements().push_back(stmt);
|
||||
return stmt;
|
||||
}
|
||||
|
||||
IfStmt *createIf() {
|
||||
auto stmt = new IfStmt();
|
||||
block->getStatements().push_back(stmt);
|
||||
return stmt;
|
||||
}
|
||||
|
||||
private:
|
||||
StmtBlock *block = nullptr;
|
||||
StmtBlock::iterator insertPoint;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -258,8 +258,8 @@ void MLFunctionState::print(const StmtBlock *block) {
|
|||
|
||||
void MLFunctionState::print(const Statement *stmt) {
|
||||
switch (stmt->getKind()) {
|
||||
case Statement::Kind::Operation: // TODO
|
||||
llvm_unreachable("Operation statement is not yet implemented");
|
||||
case Statement::Kind::Operation:
|
||||
return print(cast<OperationStmt>(stmt));
|
||||
case Statement::Kind::For:
|
||||
return print(cast<ForStmt>(stmt));
|
||||
case Statement::Kind::If:
|
||||
|
@ -274,7 +274,7 @@ void MLFunctionState::print(const OperationStmt *stmt) {
|
|||
void MLFunctionState::print(const ForStmt *stmt) {
|
||||
os.indent(numSpaces) << "for {\n";
|
||||
print(static_cast<const StmtBlock *>(stmt));
|
||||
os.indent(numSpaces) << "}\n";
|
||||
os.indent(numSpaces) << "}";
|
||||
}
|
||||
|
||||
void MLFunctionState::print(const IfStmt *stmt) {
|
||||
|
@ -286,7 +286,6 @@ void MLFunctionState::print(const IfStmt *stmt) {
|
|||
print(stmt->getElseClause());
|
||||
os.indent(numSpaces) << "}";
|
||||
}
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -84,6 +84,9 @@ private:
|
|||
|
||||
namespace {
|
||||
|
||||
typedef std::function<Operation *(Identifier, ArrayRef<NamedAttribute>)>
|
||||
CreateOperationFunction;
|
||||
|
||||
/// This class implement support for parsing global entities like types and
|
||||
/// shared entities like SSA names. It is intended to be subclassed by
|
||||
/// specialized subparsers that include state, e.g. when a local symbol table.
|
||||
|
@ -167,6 +170,9 @@ public:
|
|||
ParseResult parseSSAUseAndType();
|
||||
ParseResult parseOptionalSSAUseAndTypeList(Token::Kind endToken);
|
||||
|
||||
// Operations
|
||||
ParseResult parseOperation(const CreateOperationFunction &createOpFunc);
|
||||
|
||||
private:
|
||||
// The Parser is subclassed and reinstantiated. Do not add additional
|
||||
// non-trivial state here, add it to the ParserState class.
|
||||
|
@ -1217,6 +1223,70 @@ ParseResult Parser::parseOptionalSSAUseAndTypeList(Token::Kind endToken) {
|
|||
endToken, [&]() -> ParseResult { return parseSSAUseAndType(); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Parse the CFG or MLFunc operation.
|
||||
///
|
||||
/// TODO(clattner): This is a change from the MLIR spec as written, it is an
|
||||
/// experiment that will eliminate "builtin" instructions as a thing.
|
||||
///
|
||||
/// operation ::=
|
||||
/// (ssa-id `=`)? string '(' ssa-use-list? ')' attribute-dict?
|
||||
/// `:` function-type
|
||||
///
|
||||
ParseResult
|
||||
Parser::parseOperation(const CreateOperationFunction &createOpFunc) {
|
||||
auto loc = getToken().getLoc();
|
||||
|
||||
StringRef resultID;
|
||||
if (getToken().is(Token::percent_identifier)) {
|
||||
resultID = getTokenSpelling().drop_front();
|
||||
consumeToken(Token::percent_identifier);
|
||||
if (!consumeIf(Token::equal))
|
||||
return emitError("expected '=' after SSA name");
|
||||
}
|
||||
|
||||
if (getToken().isNot(Token::string))
|
||||
return emitError("expected operation name in quotes");
|
||||
|
||||
auto name = getToken().getStringValue();
|
||||
if (name.empty())
|
||||
return emitError("empty operation name is invalid");
|
||||
|
||||
consumeToken(Token::string);
|
||||
|
||||
if (!consumeIf(Token::l_paren))
|
||||
return emitError("expected '(' to start operand list");
|
||||
|
||||
// Parse the operand list.
|
||||
parseOptionalSSAUseList(Token::r_paren);
|
||||
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
if (getToken().is(Token::l_brace)) {
|
||||
if (parseAttributeDict(attributes))
|
||||
return ParseFailure;
|
||||
}
|
||||
|
||||
// TODO: Don't drop result name and operand names on the floor.
|
||||
auto nameId = builder.getIdentifier(name);
|
||||
|
||||
auto oper = createOpFunc(nameId, attributes);
|
||||
|
||||
if (!oper)
|
||||
return ParseFailure;
|
||||
|
||||
// We just parsed an operation. If it is a recognized one, verify that it
|
||||
// is structurally as we expect. If not, produce an error with a reasonable
|
||||
// source location.
|
||||
if (auto *opInfo = oper->getAbstractOperation(builder.getContext())) {
|
||||
if (auto error = opInfo->verifyInvariants(oper))
|
||||
return emitError(loc, error);
|
||||
}
|
||||
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CFG Functions
|
||||
|
@ -1322,72 +1392,23 @@ ParseResult CFGFunctionParser::parseBasicBlock() {
|
|||
// Set the insertion point to the block we want to insert new operations into.
|
||||
builder.setInsertionPoint(block);
|
||||
|
||||
auto createOpFunc = [this](Identifier name,
|
||||
ArrayRef<NamedAttribute> attrs) -> Operation * {
|
||||
return builder.createOperation(name, attrs);
|
||||
};
|
||||
|
||||
// Parse the list of operations that make up the body of the block.
|
||||
while (getToken().isNot(Token::kw_return, Token::kw_br)) {
|
||||
auto loc = getToken().getLoc();
|
||||
auto *inst = parseCFGOperation();
|
||||
if (!inst)
|
||||
if (parseOperation(createOpFunc))
|
||||
return ParseFailure;
|
||||
|
||||
// We just parsed an operation. If it is a recognized one, verify that it
|
||||
// is structurally as we expect. If not, produce an error with a reasonable
|
||||
// source location.
|
||||
if (auto *opInfo = inst->getAbstractOperation(builder.getContext()))
|
||||
if (auto error = opInfo->verifyInvariants(inst))
|
||||
return emitError(loc, error);
|
||||
}
|
||||
|
||||
auto *term = parseTerminator();
|
||||
if (!term)
|
||||
if (!parseTerminator())
|
||||
return ParseFailure;
|
||||
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
||||
/// Parse the CFG operation.
|
||||
///
|
||||
/// TODO(clattner): This is a change from the MLIR spec as written, it is an
|
||||
/// experiment that will eliminate "builtin" instructions as a thing.
|
||||
///
|
||||
/// cfg-operation ::=
|
||||
/// (ssa-id `=`)? string '(' ssa-use-list? ')' attribute-dict?
|
||||
/// `:` function-type
|
||||
///
|
||||
OperationInst *CFGFunctionParser::parseCFGOperation() {
|
||||
StringRef resultID;
|
||||
if (getToken().is(Token::percent_identifier)) {
|
||||
resultID = getTokenSpelling().drop_front();
|
||||
consumeToken();
|
||||
if (!consumeIf(Token::equal))
|
||||
return (emitError("expected '=' after SSA name"), nullptr);
|
||||
}
|
||||
|
||||
if (getToken().isNot(Token::string))
|
||||
return (emitError("expected operation name in quotes"), nullptr);
|
||||
|
||||
auto name = getToken().getStringValue();
|
||||
if (name.empty())
|
||||
return (emitError("empty operation name is invalid"), nullptr);
|
||||
|
||||
consumeToken(Token::string);
|
||||
|
||||
if (!consumeIf(Token::l_paren))
|
||||
return (emitError("expected '(' to start operand list"), nullptr);
|
||||
|
||||
// Parse the operand list.
|
||||
parseOptionalSSAUseList(Token::r_paren);
|
||||
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
if (getToken().is(Token::l_brace)) {
|
||||
if (parseAttributeDict(attributes))
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// TODO: Don't drop result name and operand names on the floor.
|
||||
auto nameId = builder.getIdentifier(name);
|
||||
return builder.createOperation(nameId, attributes);
|
||||
}
|
||||
|
||||
/// Parse the terminator instruction for a basic block.
|
||||
///
|
||||
/// terminator-stmt ::= `br` bb-id branch-use-list?
|
||||
|
@ -1424,22 +1445,22 @@ namespace {
|
|||
/// Refined parser for MLFunction bodies.
|
||||
class MLFunctionParser : public Parser {
|
||||
public:
|
||||
MLFunction *function;
|
||||
|
||||
/// This builder intentionally shadows the builder in the base class, with a
|
||||
/// more specific builder type.
|
||||
// TODO: MLFuncBuilder builder;
|
||||
|
||||
MLFunctionParser(ParserState &state, MLFunction *function)
|
||||
: Parser(state), function(function) {}
|
||||
: Parser(state), function(function), builder(function) {}
|
||||
|
||||
ParseResult parseFunctionBody();
|
||||
|
||||
private:
|
||||
Statement *parseStatement();
|
||||
ForStmt *parseForStmt();
|
||||
IfStmt *parseIfStmt();
|
||||
MLFunction *function;
|
||||
|
||||
/// This builder intentionally shadows the builder in the base class, with a
|
||||
/// more specific builder type.
|
||||
MLFuncBuilder builder;
|
||||
|
||||
ParseResult parseForStmt();
|
||||
ParseResult parseIfStmt();
|
||||
ParseResult parseElseClause(IfClause *elseClause);
|
||||
ParseResult parseStatements(StmtBlock *block);
|
||||
ParseResult parseStmtBlock(StmtBlock *block);
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
@ -1448,19 +1469,14 @@ ParseResult MLFunctionParser::parseFunctionBody() {
|
|||
if (!consumeIf(Token::l_brace))
|
||||
return emitError("expected '{' in ML function");
|
||||
|
||||
// Make sure we have at least one statement.
|
||||
if (getToken().is(Token::r_brace))
|
||||
return emitError("ML function must end with return statement");
|
||||
|
||||
// Parse the list of instructions.
|
||||
while (!consumeIf(Token::kw_return)) {
|
||||
auto *stmt = parseStatement();
|
||||
if (!stmt)
|
||||
return ParseFailure;
|
||||
function->push_back(stmt);
|
||||
}
|
||||
// Parse statements in this function
|
||||
if (parseStatements(function))
|
||||
return ParseFailure;
|
||||
|
||||
if (!consumeIf(Token::kw_return))
|
||||
emitError("ML function must end with return statement");
|
||||
// TODO: parse return statement operands
|
||||
|
||||
if (!consumeIf(Token::r_brace))
|
||||
emitError("expected '}' in ML function");
|
||||
|
||||
|
@ -1469,42 +1485,23 @@ ParseResult MLFunctionParser::parseFunctionBody() {
|
|||
return ParseSuccess;
|
||||
}
|
||||
|
||||
/// Statement.
|
||||
///
|
||||
/// ml-stmt ::= instruction | ml-for-stmt | ml-if-stmt
|
||||
///
|
||||
/// TODO: fix terminology in MLSpec document. ML functions
|
||||
/// contain operation statements, not instructions.
|
||||
///
|
||||
Statement *MLFunctionParser::parseStatement() {
|
||||
switch (getToken().getKind()) {
|
||||
default:
|
||||
//TODO: parse OperationStmt
|
||||
return (emitError("expected statement"), nullptr);
|
||||
|
||||
case Token::kw_for:
|
||||
return parseForStmt();
|
||||
|
||||
case Token::kw_if:
|
||||
return parseIfStmt();
|
||||
}
|
||||
}
|
||||
|
||||
/// For statement.
|
||||
///
|
||||
/// ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
|
||||
/// (`step` integer-literal)? `{` ml-stmt* `}`
|
||||
///
|
||||
ForStmt *MLFunctionParser::parseForStmt() {
|
||||
ParseResult MLFunctionParser::parseForStmt() {
|
||||
consumeToken(Token::kw_for);
|
||||
|
||||
//TODO: parse loop header
|
||||
ForStmt *stmt = new ForStmt();
|
||||
if (parseStmtBlock(static_cast<StmtBlock *>(stmt))) {
|
||||
delete stmt;
|
||||
return nullptr;
|
||||
}
|
||||
return stmt;
|
||||
ForStmt *stmt = builder.createFor();
|
||||
|
||||
// If parsing of the for statement body fails
|
||||
// MLIR contains for statement with successfully parsed nested statements
|
||||
if (parseStmtBlock(static_cast<StmtBlock *>(stmt)))
|
||||
return ParseFailure;
|
||||
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
||||
/// If statement.
|
||||
|
@ -1514,45 +1511,69 @@ ForStmt *MLFunctionParser::parseForStmt() {
|
|||
/// ml-if-stmt ::= ml-if-head
|
||||
/// | ml-if-head `else` `{` ml-stmt* `}`
|
||||
///
|
||||
IfStmt *MLFunctionParser::parseIfStmt() {
|
||||
ParseResult MLFunctionParser::parseIfStmt() {
|
||||
consumeToken(Token::kw_if);
|
||||
if (!consumeIf(Token::l_paren))
|
||||
return (emitError("expected ("), nullptr);
|
||||
return emitError("expected (");
|
||||
|
||||
//TODO: parse condition
|
||||
|
||||
if (!consumeIf(Token::r_paren))
|
||||
return (emitError("expected )"), nullptr);
|
||||
return emitError("expected )");
|
||||
|
||||
IfStmt *ifStmt = new IfStmt();
|
||||
IfStmt *ifStmt = builder.createIf();
|
||||
IfClause *thenClause = ifStmt->getThenClause();
|
||||
if (parseStmtBlock(thenClause)) {
|
||||
delete ifStmt;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// If parsing of the then or optional else clause fails MLIR contains
|
||||
// if statement with successfully parsed nested statements.
|
||||
if (parseStmtBlock(thenClause))
|
||||
return ParseFailure;
|
||||
|
||||
if (consumeIf(Token::kw_else)) {
|
||||
IfClause *elseClause = ifStmt->createElseClause();
|
||||
if (parseElseClause(elseClause)) {
|
||||
delete ifStmt;
|
||||
return nullptr;
|
||||
}
|
||||
if (parseElseClause(elseClause))
|
||||
return ParseFailure;
|
||||
}
|
||||
|
||||
return ifStmt;
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
||||
ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) {
|
||||
if (getToken().is(Token::kw_if)) {
|
||||
IfStmt *nextIf = parseIfStmt();
|
||||
if (!nextIf)
|
||||
return ParseFailure;
|
||||
elseClause->push_back(nextIf);
|
||||
return ParseSuccess;
|
||||
builder.setInsertionPoint(elseClause);
|
||||
return parseIfStmt();
|
||||
}
|
||||
|
||||
if (parseStmtBlock(elseClause))
|
||||
return ParseFailure;
|
||||
return parseStmtBlock(elseClause);
|
||||
}
|
||||
|
||||
///
|
||||
/// Parse a list of statements ending with `return` or `}`
|
||||
///
|
||||
ParseResult MLFunctionParser::parseStatements(StmtBlock *block) {
|
||||
auto createOpFunc = [this](Identifier name,
|
||||
ArrayRef<NamedAttribute> attrs) -> Operation * {
|
||||
return builder.createOperation(name, attrs);
|
||||
};
|
||||
|
||||
builder.setInsertionPoint(block);
|
||||
|
||||
while (getToken().isNot(Token::kw_return, Token::r_brace)) {
|
||||
switch (getToken().getKind()) {
|
||||
default:
|
||||
if (parseOperation(createOpFunc))
|
||||
return ParseFailure;
|
||||
break;
|
||||
case Token::kw_for:
|
||||
if (parseForStmt())
|
||||
return ParseFailure;
|
||||
break;
|
||||
case Token::kw_if:
|
||||
if (parseIfStmt())
|
||||
return ParseFailure;
|
||||
break;
|
||||
} // end switch
|
||||
}
|
||||
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
@ -1564,12 +1585,11 @@ ParseResult MLFunctionParser::parseStmtBlock(StmtBlock *block) {
|
|||
if (!consumeIf(Token::l_brace))
|
||||
return emitError("expected '{' before statement list");
|
||||
|
||||
while (!consumeIf(Token::r_brace)) {
|
||||
auto *stmt = parseStatement();
|
||||
if (!stmt)
|
||||
return ParseFailure;
|
||||
block->push_back(stmt);
|
||||
}
|
||||
if (parseStatements(block))
|
||||
return ParseFailure;
|
||||
|
||||
if (!consumeIf(Token::r_brace))
|
||||
return emitError("expected '}' at the end of the statement block");
|
||||
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
|
|
@ -130,7 +130,7 @@ mlfunc @incomplete_for() {
|
|||
// -----
|
||||
|
||||
mlfunc @non_statement() {
|
||||
asd // expected-error {{expected statement}}
|
||||
asd // expected-error {{expected operation name in quotes}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -78,11 +78,21 @@ bb4: // CHECK: bb3:
|
|||
return // CHECK: return
|
||||
} // CHECK: }
|
||||
|
||||
// CHECK-LABEL: mlfunc @simpleMLF() {
|
||||
mlfunc @simpleMLF() {
|
||||
// CHECK-LABEL: mlfunc @emptyMLF() {
|
||||
mlfunc @emptyMLF() {
|
||||
return // CHECK: return
|
||||
} // CHECK: }
|
||||
|
||||
// CHECK-LABEL: mlfunc @mlfunc_with_ops() {
|
||||
mlfunc @mlfunc_with_ops() {
|
||||
// CHECK: dim xxx, 2 : sometype
|
||||
%a = "dim"(%42){index: 2}
|
||||
|
||||
// CHECK: addf xx, yy : sometype
|
||||
"addf"()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @loops() {
|
||||
mlfunc @loops() {
|
||||
for { // CHECK: for {
|
||||
|
|
Loading…
Reference in New Issue