Parse operations in ML functions. Add builder class for ML functions.

Refactors operation parsing to share functionality between CFG and ML functions. ML function construction now goes through a builder, similar to the way it is done for
CFG functions.

PiperOrigin-RevId: 204779279
This commit is contained in:
Tatiana Shpeisman 2018-07-16 11:47:09 -07:00 committed by jpienaar
parent 8e8114a96d
commit fc7d6dbe5e
5 changed files with 217 additions and 139 deletions

View File

@ -19,6 +19,8 @@
#define MLIR_IR_BUILDERS_H
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Statements.h"
namespace mlir {
class MLIRContext;
@ -146,7 +148,54 @@ private:
BasicBlock::iterator insertPoint;
};
// TODO: MLFuncBuilder
/// This class helps build an MLFunction. Statements that are created are
/// automatically inserted at an insertion point or added to the current
/// statement block.
class MLFuncBuilder : public Builder {
public:
MLFuncBuilder(MLFunction *function) : Builder(function->getContext()) {}
MLFuncBuilder(StmtBlock *block) : MLFuncBuilder(block->getFunction()) {
setInsertionPoint(block);
}
/// Reset the insertion point to no location. Creating an operation without a
/// set insertion point is an error, but this can still be useful when the
/// current insertion point a builder refers to is being removed.
void clearInsertionPoint() {
this->block = nullptr;
insertPoint = StmtBlock::iterator();
}
/// Set the insertion point to the end of the specified block.
void setInsertionPoint(StmtBlock *block) {
this->block = block;
insertPoint = block->end();
}
OperationStmt *createOperation(Identifier name,
ArrayRef<NamedAttribute> attributes) {
auto op = new OperationStmt(name, attributes, context);
block->getStatements().push_back(op);
return op;
}
ForStmt *createFor() {
auto stmt = new ForStmt();
block->getStatements().push_back(stmt);
return stmt;
}
IfStmt *createIf() {
auto stmt = new IfStmt();
block->getStatements().push_back(stmt);
return stmt;
}
private:
StmtBlock *block = nullptr;
StmtBlock::iterator insertPoint;
};
} // namespace mlir

View File

@ -258,8 +258,8 @@ void MLFunctionState::print(const StmtBlock *block) {
void MLFunctionState::print(const Statement *stmt) {
switch (stmt->getKind()) {
case Statement::Kind::Operation: // TODO
llvm_unreachable("Operation statement is not yet implemented");
case Statement::Kind::Operation:
return print(cast<OperationStmt>(stmt));
case Statement::Kind::For:
return print(cast<ForStmt>(stmt));
case Statement::Kind::If:
@ -274,7 +274,7 @@ void MLFunctionState::print(const OperationStmt *stmt) {
void MLFunctionState::print(const ForStmt *stmt) {
os.indent(numSpaces) << "for {\n";
print(static_cast<const StmtBlock *>(stmt));
os.indent(numSpaces) << "}\n";
os.indent(numSpaces) << "}";
}
void MLFunctionState::print(const IfStmt *stmt) {
@ -286,7 +286,6 @@ void MLFunctionState::print(const IfStmt *stmt) {
print(stmt->getElseClause());
os.indent(numSpaces) << "}";
}
os << "\n";
}
//===----------------------------------------------------------------------===//

View File

@ -84,6 +84,9 @@ private:
namespace {
typedef std::function<Operation *(Identifier, ArrayRef<NamedAttribute>)>
CreateOperationFunction;
/// This class implement support for parsing global entities like types and
/// shared entities like SSA names. It is intended to be subclassed by
/// specialized subparsers that include state, e.g. when a local symbol table.
@ -167,6 +170,9 @@ public:
ParseResult parseSSAUseAndType();
ParseResult parseOptionalSSAUseAndTypeList(Token::Kind endToken);
// Operations
ParseResult parseOperation(const CreateOperationFunction &createOpFunc);
private:
// The Parser is subclassed and reinstantiated. Do not add additional
// non-trivial state here, add it to the ParserState class.
@ -1217,6 +1223,70 @@ ParseResult Parser::parseOptionalSSAUseAndTypeList(Token::Kind endToken) {
endToken, [&]() -> ParseResult { return parseSSAUseAndType(); });
}
//===----------------------------------------------------------------------===//
// Operations
//===----------------------------------------------------------------------===//
/// Parse the CFG or MLFunc operation.
///
/// TODO(clattner): This is a change from the MLIR spec as written, it is an
/// experiment that will eliminate "builtin" instructions as a thing.
///
/// operation ::=
/// (ssa-id `=`)? string '(' ssa-use-list? ')' attribute-dict?
/// `:` function-type
///
ParseResult
Parser::parseOperation(const CreateOperationFunction &createOpFunc) {
auto loc = getToken().getLoc();
StringRef resultID;
if (getToken().is(Token::percent_identifier)) {
resultID = getTokenSpelling().drop_front();
consumeToken(Token::percent_identifier);
if (!consumeIf(Token::equal))
return emitError("expected '=' after SSA name");
}
if (getToken().isNot(Token::string))
return emitError("expected operation name in quotes");
auto name = getToken().getStringValue();
if (name.empty())
return emitError("empty operation name is invalid");
consumeToken(Token::string);
if (!consumeIf(Token::l_paren))
return emitError("expected '(' to start operand list");
// Parse the operand list.
parseOptionalSSAUseList(Token::r_paren);
SmallVector<NamedAttribute, 4> attributes;
if (getToken().is(Token::l_brace)) {
if (parseAttributeDict(attributes))
return ParseFailure;
}
// TODO: Don't drop result name and operand names on the floor.
auto nameId = builder.getIdentifier(name);
auto oper = createOpFunc(nameId, attributes);
if (!oper)
return ParseFailure;
// We just parsed an operation. If it is a recognized one, verify that it
// is structurally as we expect. If not, produce an error with a reasonable
// source location.
if (auto *opInfo = oper->getAbstractOperation(builder.getContext())) {
if (auto error = opInfo->verifyInvariants(oper))
return emitError(loc, error);
}
return ParseSuccess;
}
//===----------------------------------------------------------------------===//
// CFG Functions
@ -1322,72 +1392,23 @@ ParseResult CFGFunctionParser::parseBasicBlock() {
// Set the insertion point to the block we want to insert new operations into.
builder.setInsertionPoint(block);
auto createOpFunc = [this](Identifier name,
ArrayRef<NamedAttribute> attrs) -> Operation * {
return builder.createOperation(name, attrs);
};
// Parse the list of operations that make up the body of the block.
while (getToken().isNot(Token::kw_return, Token::kw_br)) {
auto loc = getToken().getLoc();
auto *inst = parseCFGOperation();
if (!inst)
if (parseOperation(createOpFunc))
return ParseFailure;
// We just parsed an operation. If it is a recognized one, verify that it
// is structurally as we expect. If not, produce an error with a reasonable
// source location.
if (auto *opInfo = inst->getAbstractOperation(builder.getContext()))
if (auto error = opInfo->verifyInvariants(inst))
return emitError(loc, error);
}
auto *term = parseTerminator();
if (!term)
if (!parseTerminator())
return ParseFailure;
return ParseSuccess;
}
/// Parse the CFG operation.
///
/// TODO(clattner): This is a change from the MLIR spec as written, it is an
/// experiment that will eliminate "builtin" instructions as a thing.
///
/// cfg-operation ::=
/// (ssa-id `=`)? string '(' ssa-use-list? ')' attribute-dict?
/// `:` function-type
///
OperationInst *CFGFunctionParser::parseCFGOperation() {
StringRef resultID;
if (getToken().is(Token::percent_identifier)) {
resultID = getTokenSpelling().drop_front();
consumeToken();
if (!consumeIf(Token::equal))
return (emitError("expected '=' after SSA name"), nullptr);
}
if (getToken().isNot(Token::string))
return (emitError("expected operation name in quotes"), nullptr);
auto name = getToken().getStringValue();
if (name.empty())
return (emitError("empty operation name is invalid"), nullptr);
consumeToken(Token::string);
if (!consumeIf(Token::l_paren))
return (emitError("expected '(' to start operand list"), nullptr);
// Parse the operand list.
parseOptionalSSAUseList(Token::r_paren);
SmallVector<NamedAttribute, 4> attributes;
if (getToken().is(Token::l_brace)) {
if (parseAttributeDict(attributes))
return nullptr;
}
// TODO: Don't drop result name and operand names on the floor.
auto nameId = builder.getIdentifier(name);
return builder.createOperation(nameId, attributes);
}
/// Parse the terminator instruction for a basic block.
///
/// terminator-stmt ::= `br` bb-id branch-use-list?
@ -1424,22 +1445,22 @@ namespace {
/// Refined parser for MLFunction bodies.
class MLFunctionParser : public Parser {
public:
MLFunction *function;
/// This builder intentionally shadows the builder in the base class, with a
/// more specific builder type.
// TODO: MLFuncBuilder builder;
MLFunctionParser(ParserState &state, MLFunction *function)
: Parser(state), function(function) {}
: Parser(state), function(function), builder(function) {}
ParseResult parseFunctionBody();
private:
Statement *parseStatement();
ForStmt *parseForStmt();
IfStmt *parseIfStmt();
MLFunction *function;
/// This builder intentionally shadows the builder in the base class, with a
/// more specific builder type.
MLFuncBuilder builder;
ParseResult parseForStmt();
ParseResult parseIfStmt();
ParseResult parseElseClause(IfClause *elseClause);
ParseResult parseStatements(StmtBlock *block);
ParseResult parseStmtBlock(StmtBlock *block);
};
} // end anonymous namespace
@ -1448,19 +1469,14 @@ ParseResult MLFunctionParser::parseFunctionBody() {
if (!consumeIf(Token::l_brace))
return emitError("expected '{' in ML function");
// Make sure we have at least one statement.
if (getToken().is(Token::r_brace))
return emitError("ML function must end with return statement");
// Parse the list of instructions.
while (!consumeIf(Token::kw_return)) {
auto *stmt = parseStatement();
if (!stmt)
return ParseFailure;
function->push_back(stmt);
}
// Parse statements in this function
if (parseStatements(function))
return ParseFailure;
if (!consumeIf(Token::kw_return))
emitError("ML function must end with return statement");
// TODO: parse return statement operands
if (!consumeIf(Token::r_brace))
emitError("expected '}' in ML function");
@ -1469,42 +1485,23 @@ ParseResult MLFunctionParser::parseFunctionBody() {
return ParseSuccess;
}
/// Statement.
///
/// ml-stmt ::= instruction | ml-for-stmt | ml-if-stmt
///
/// TODO: fix terminology in MLSpec document. ML functions
/// contain operation statements, not instructions.
///
Statement *MLFunctionParser::parseStatement() {
switch (getToken().getKind()) {
default:
//TODO: parse OperationStmt
return (emitError("expected statement"), nullptr);
case Token::kw_for:
return parseForStmt();
case Token::kw_if:
return parseIfStmt();
}
}
/// For statement.
///
/// ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
/// (`step` integer-literal)? `{` ml-stmt* `}`
///
ForStmt *MLFunctionParser::parseForStmt() {
ParseResult MLFunctionParser::parseForStmt() {
consumeToken(Token::kw_for);
//TODO: parse loop header
ForStmt *stmt = new ForStmt();
if (parseStmtBlock(static_cast<StmtBlock *>(stmt))) {
delete stmt;
return nullptr;
}
return stmt;
ForStmt *stmt = builder.createFor();
// If parsing of the for statement body fails
// MLIR contains for statement with successfully parsed nested statements
if (parseStmtBlock(static_cast<StmtBlock *>(stmt)))
return ParseFailure;
return ParseSuccess;
}
/// If statement.
@ -1514,45 +1511,69 @@ ForStmt *MLFunctionParser::parseForStmt() {
/// ml-if-stmt ::= ml-if-head
/// | ml-if-head `else` `{` ml-stmt* `}`
///
IfStmt *MLFunctionParser::parseIfStmt() {
ParseResult MLFunctionParser::parseIfStmt() {
consumeToken(Token::kw_if);
if (!consumeIf(Token::l_paren))
return (emitError("expected ("), nullptr);
return emitError("expected (");
//TODO: parse condition
if (!consumeIf(Token::r_paren))
return (emitError("expected )"), nullptr);
return emitError("expected )");
IfStmt *ifStmt = new IfStmt();
IfStmt *ifStmt = builder.createIf();
IfClause *thenClause = ifStmt->getThenClause();
if (parseStmtBlock(thenClause)) {
delete ifStmt;
return nullptr;
}
// If parsing of the then or optional else clause fails MLIR contains
// if statement with successfully parsed nested statements.
if (parseStmtBlock(thenClause))
return ParseFailure;
if (consumeIf(Token::kw_else)) {
IfClause *elseClause = ifStmt->createElseClause();
if (parseElseClause(elseClause)) {
delete ifStmt;
return nullptr;
}
if (parseElseClause(elseClause))
return ParseFailure;
}
return ifStmt;
return ParseSuccess;
}
ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) {
if (getToken().is(Token::kw_if)) {
IfStmt *nextIf = parseIfStmt();
if (!nextIf)
return ParseFailure;
elseClause->push_back(nextIf);
return ParseSuccess;
builder.setInsertionPoint(elseClause);
return parseIfStmt();
}
if (parseStmtBlock(elseClause))
return ParseFailure;
return parseStmtBlock(elseClause);
}
///
/// Parse a list of statements ending with `return` or `}`
///
ParseResult MLFunctionParser::parseStatements(StmtBlock *block) {
auto createOpFunc = [this](Identifier name,
ArrayRef<NamedAttribute> attrs) -> Operation * {
return builder.createOperation(name, attrs);
};
builder.setInsertionPoint(block);
while (getToken().isNot(Token::kw_return, Token::r_brace)) {
switch (getToken().getKind()) {
default:
if (parseOperation(createOpFunc))
return ParseFailure;
break;
case Token::kw_for:
if (parseForStmt())
return ParseFailure;
break;
case Token::kw_if:
if (parseIfStmt())
return ParseFailure;
break;
} // end switch
}
return ParseSuccess;
}
@ -1564,12 +1585,11 @@ ParseResult MLFunctionParser::parseStmtBlock(StmtBlock *block) {
if (!consumeIf(Token::l_brace))
return emitError("expected '{' before statement list");
while (!consumeIf(Token::r_brace)) {
auto *stmt = parseStatement();
if (!stmt)
return ParseFailure;
block->push_back(stmt);
}
if (parseStatements(block))
return ParseFailure;
if (!consumeIf(Token::r_brace))
return emitError("expected '}' at the end of the statement block");
return ParseSuccess;
}

View File

@ -130,7 +130,7 @@ mlfunc @incomplete_for() {
// -----
mlfunc @non_statement() {
asd // expected-error {{expected statement}}
asd // expected-error {{expected operation name in quotes}}
}
// -----

View File

@ -78,11 +78,21 @@ bb4: // CHECK: bb3:
return // CHECK: return
} // CHECK: }
// CHECK-LABEL: mlfunc @simpleMLF() {
mlfunc @simpleMLF() {
// CHECK-LABEL: mlfunc @emptyMLF() {
mlfunc @emptyMLF() {
return // CHECK: return
} // CHECK: }
// CHECK-LABEL: mlfunc @mlfunc_with_ops() {
mlfunc @mlfunc_with_ops() {
// CHECK: dim xxx, 2 : sometype
%a = "dim"(%42){index: 2}
// CHECK: addf xx, yy : sometype
"addf"()
return
}
// CHECK-LABEL: mlfunc @loops() {
mlfunc @loops() {
for { // CHECK: for {