Parse ML function arguments, return statement operands, and for statement loop header.

Loop bounds and presumed to be constants for now and are stored in ForStmt as affine constant expressions.  ML function arguments, return statement operands and loop variable name are dropped for now.

PiperOrigin-RevId: 205256208
This commit is contained in:
Tatiana Shpeisman 2018-07-19 09:52:39 -07:00 committed by jpienaar
parent 72c24e3e71
commit 6ada91db02
9 changed files with 212 additions and 45 deletions

View File

@ -184,11 +184,10 @@ public:
return op;
}
ForStmt *createFor() {
auto stmt = new ForStmt();
block->getStatements().push_back(stmt);
return stmt;
}
// Creates for statement. When step is not specified, it is set to 1.
ForStmt *createFor(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound,
AffineConstantExpr *step = nullptr);
IfStmt *createIf() {
auto stmt = new IfStmt();

View File

@ -22,10 +22,11 @@
#ifndef MLIR_IR_STATEMENTS_H
#define MLIR_IR_STATEMENTS_H
#include "mlir/Support/LLVM.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Statement.h"
#include "mlir/IR/StmtBlock.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
@ -50,11 +51,20 @@ public:
/// For statement represents an affine loop nest.
class ForStmt : public Statement, public StmtBlock {
public:
explicit ForStmt() : Statement(Kind::For), StmtBlock(StmtBlockKind::For) {}
// TODO: lower and upper bounds should be affine maps with
// dimension and symbol use lists.
explicit ForStmt(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound, AffineConstantExpr *step)
: Statement(Kind::For), StmtBlock(StmtBlockKind::For),
lowerBound(lowerBound), upperBound(upperBound), step(step) {}
//TODO: delete nested statements or assert that they are gone.
~ForStmt() {}
// TODO: represent loop variable, bounds and step
// TODO: represent induction variable
AffineConstantExpr *getLowerBound() const { return lowerBound; }
AffineConstantExpr *getUpperBound() const { return upperBound; }
AffineConstantExpr *getStep() const { return step; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Statement *stmt) {
@ -64,9 +74,14 @@ public:
static bool classof(const StmtBlock *block) {
return block->getStmtBlockKind() == StmtBlockKind::For;
}
private:
AffineConstantExpr *lowerBound;
AffineConstantExpr *upperBound;
AffineConstantExpr *step;
};
/// If clause represents statements contained within then or else clause
/// An if clause represents statements contained within a then or an else clause
/// of an if statement.
class IfClause : public StmtBlock {
public:

View File

@ -594,7 +594,12 @@ void MLFunctionState::print(const Statement *stmt) {
void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
void MLFunctionState::print(const ForStmt *stmt) {
os.indent(numSpaces) << "for {\n";
os.indent(numSpaces) << "for x = " << *stmt->getLowerBound();
os << " to " << *stmt->getUpperBound();
if (stmt->getStep()->getValue() != 1)
os << " step " << *stmt->getStep();
os << " {\n";
print(static_cast<const StmtBlock *>(stmt));
os.indent(numSpaces) << "}";
}

View File

@ -143,3 +143,17 @@ AffineExpr *Builder::getFloorDivExpr(AffineExpr *lhs, AffineExpr *rhs) {
AffineExpr *Builder::getCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs) {
return AffineCeilDivExpr::get(lhs, rhs, context);
}
//===----------------------------------------------------------------------===//
// Statements
//===----------------------------------------------------------------------===//
ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound,
AffineConstantExpr *step) {
if (!step)
step = getConstantExpr(1);
auto stmt = new ForStmt(lowerBound, upperBound, step);
block->getStatements().push_back(stmt);
return stmt;
}

View File

@ -1581,6 +1581,7 @@ private:
MLFuncBuilder builder;
ParseResult parseForStmt();
AffineConstantExpr *parseIntConstant();
ParseResult parseIfStmt();
ParseResult parseElseClause(IfClause *elseClause);
ParseResult parseStatements(StmtBlock *block);
@ -1598,10 +1599,11 @@ ParseResult MLFunctionParser::parseFunctionBody() {
if (!consumeIf(Token::kw_return))
emitError("ML function must end with return statement");
// TODO: parse return statement operands
if (!consumeIf(Token::r_brace))
emitError("expected '}' in ML function");
// TODO: store return operands in the IR.
SmallVector<SSAUseInfo, 4> dummyUseInfo;
if (parseOptionalSSAUseList(Token::r_brace, dummyUseInfo))
return ParseFailure;
getModule()->functionList.push_back(function);
@ -1616,17 +1618,66 @@ ParseResult MLFunctionParser::parseFunctionBody() {
ParseResult MLFunctionParser::parseForStmt() {
consumeToken(Token::kw_for);
//TODO: parse loop header
ForStmt *stmt = builder.createFor();
// Parse induction variable
if (getToken().isNot(Token::percent_identifier))
return emitError("expected SSA identifier for the loop variable");
// If parsing of the for statement body fails
// MLIR contains for statement with successfully parsed nested statements
// TODO: create SSA value definition from name
StringRef name = getTokenSpelling().drop_front();
(void)name;
consumeToken(Token::percent_identifier);
if (!consumeIf(Token::equal))
return emitError("expected =");
// Parse loop bounds
AffineConstantExpr *lowerBound = parseIntConstant();
if (!lowerBound)
return ParseFailure;
if (!consumeIf(Token::kw_to))
return emitError("expected 'to' between bounds");
AffineConstantExpr *upperBound = parseIntConstant();
if (!upperBound)
return ParseFailure;
// Parse step
AffineConstantExpr *step = nullptr;
if (consumeIf(Token::kw_step)) {
step = parseIntConstant();
if (!step)
return ParseFailure;
}
// Create for statement.
ForStmt *stmt = builder.createFor(lowerBound, upperBound, step);
// If parsing of the for statement body fails,
// MLIR contains for statement with those nested statements that have been
// successfully parsed.
if (parseStmtBlock(static_cast<StmtBlock *>(stmt)))
return ParseFailure;
return ParseSuccess;
}
// This method is temporary workaround to parse simple loop bounds and
// step.
// TODO: remove this method once it's no longer used.
AffineConstantExpr *MLFunctionParser::parseIntConstant() {
if (getToken().isNot(Token::integer))
return (emitError("expected non-negative integer for now"), nullptr);
auto val = getToken().getUInt64IntegerValue();
if (!val.hasValue() || (int64_t)val.getValue() < 0) {
return (emitError("constant too large for affineint"), nullptr);
}
consumeToken(Token::integer);
return builder.getConstantExpr((int64_t)val.getValue());
}
/// If statement.
///
/// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}`
@ -1642,13 +1693,14 @@ ParseResult MLFunctionParser::parseIfStmt() {
//TODO: parse condition
if (!consumeIf(Token::r_paren))
return emitError("expected )");
return emitError("expected ')'");
IfStmt *ifStmt = builder.createIf();
IfClause *thenClause = ifStmt->getThenClause();
// If parsing of the then or optional else clause fails MLIR contains
// if statement with successfully parsed nested statements.
// When parsing of an if statement body fails, the IR contains
// the if statement with the portion of the body that has been
// successfully parsed.
if (parseStmtBlock(thenClause))
return ParseFailure;
@ -1735,7 +1787,10 @@ private:
ParseResult parseAffineMapDef();
// Functions.
ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type);
ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
SmallVectorImpl<StringRef> &argNames);
ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type,
SmallVectorImpl<StringRef> *argNames);
ParseResult parseExtFunc();
ParseResult parseCFGFunc();
ParseResult parseMLFunc();
@ -1769,14 +1824,50 @@ ParseResult ModuleParser::parseAffineMapDef() {
return ParseSuccess;
}
/// Parse a (possibly empty) list of MLFunction arguments with types.
///
/// ml-argument ::= ssa-id `:` type
/// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
///
ParseResult
ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
SmallVectorImpl<StringRef> &argNames) {
auto parseElt = [&]() -> ParseResult {
// Parse argument name
if (getToken().isNot(Token::percent_identifier))
return emitError("expected SSA identifier");
StringRef name = getTokenSpelling().drop_front();
consumeToken(Token::percent_identifier);
argNames.push_back(name);
if (!consumeIf(Token::colon))
return emitError("expected ':'");
// Parse argument type
auto elt = parseType();
if (!elt)
return ParseFailure;
argTypes.push_back(elt);
return ParseSuccess;
};
if (!consumeIf(Token::l_paren))
llvm_unreachable("expected '('");
return parseCommaSeparatedList(Token::r_paren, parseElt);
}
/// Parse a function signature, starting with a name and including the parameter
/// list.
///
/// argument-list ::= type (`,` type)* | /*empty*/
/// argument-list ::= type (`,` type)* | /*empty*/ | ml-argument-list
/// function-signature ::= function-id `(` argument-list `)` (`->` type-list)?
///
ParseResult ModuleParser::parseFunctionSignature(StringRef &name,
FunctionType *&type) {
ParseResult
ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
SmallVectorImpl<StringRef> *argNames) {
if (getToken().isNot(Token::at_identifier))
return emitError("expected a function identifier like '@foo'");
@ -1786,8 +1877,15 @@ ParseResult ModuleParser::parseFunctionSignature(StringRef &name,
if (getToken().isNot(Token::l_paren))
return emitError("expected '(' in function signature");
SmallVector<Type *, 4> arguments;
if (parseTypeList(arguments))
SmallVector<Type *, 4> argTypes;
ParseResult parseResult;
if (argNames)
parseResult = parseMLArgumentList(argTypes, *argNames);
else
parseResult = parseTypeList(argTypes);
if (parseResult)
return ParseFailure;
// Parse the return type if present.
@ -1796,7 +1894,7 @@ ParseResult ModuleParser::parseFunctionSignature(StringRef &name,
if (parseTypeList(results))
return ParseFailure;
}
type = builder.getFunctionType(arguments, results);
type = builder.getFunctionType(argTypes, results);
return ParseSuccess;
}
@ -1809,7 +1907,7 @@ ParseResult ModuleParser::parseExtFunc() {
StringRef name;
FunctionType *type = nullptr;
if (parseFunctionSignature(name, type))
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
return ParseFailure;
// Okay, the external function definition was parsed correctly.
@ -1826,7 +1924,7 @@ ParseResult ModuleParser::parseCFGFunc() {
StringRef name;
FunctionType *type = nullptr;
if (parseFunctionSignature(name, type))
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
return ParseFailure;
// Okay, the CFG function signature was parsed correctly, create the function.
@ -1844,10 +1942,11 @@ ParseResult ModuleParser::parseMLFunc() {
StringRef name;
FunctionType *type = nullptr;
SmallVector<StringRef, 4> argNames;
// FIXME: Parse ML function signature (args + types)
// by passing pointer to SmallVector<identifier> into parseFunctionSignature
if (parseFunctionSignature(name, type))
if (parseFunctionSignature(name, type, &argNames))
return ParseFailure;
// Okay, the ML function signature was parsed correctly, create the function.

View File

@ -76,7 +76,7 @@ public:
/// return None.
Optional<unsigned> getUnsignedIntegerValue() const;
/// For an integer token, return its value as an int64_t. If it doesn't fit,
/// For an integer token, return its value as an uint64_t. If it doesn't fit,
/// return None.
Optional<uint64_t> getUInt64IntegerValue() const;

View File

@ -104,7 +104,9 @@ TOK_KEYWORD(mlfunc)
TOK_KEYWORD(mod)
TOK_KEYWORD(return)
TOK_KEYWORD(size)
TOK_KEYWORD(step)
TOK_KEYWORD(tensor)
TOK_KEYWORD(to)
TOK_KEYWORD(true)
TOK_KEYWORD(vector)

View File

@ -130,12 +130,24 @@ extfunc @illegaltype(i0) // expected-error {{invalid integer width}}
// -----
mlfunc @malformed_for() {
for %i = 1 too 10 { // expected-error {{expected 'to' between bounds}}
}
}
// -----
mlfunc @incomplete_for() {
for
for %i = 1 to 10 step 2
} // expected-error {{expected '{' before statement list}}
// -----
mlfunc @nonconstant_step(%1 : i32) {
for %2 = 1 to 5 step %1 { // expected-error {{expected non-negative integer for now}}
// -----
mlfunc @non_statement() {
asd // expected-error {{expected operation name in quotes}}
}
@ -160,7 +172,6 @@ bb40:
return
}
// -----
cfgfunc @redef() {
@ -168,4 +179,16 @@ bb42:
%x = "dim"(){index: 0} : ()->i32
%x = "dim"(){index: 0} : ()->i32 // expected-error {{redefinition of SSA value %x}}
return
}
}
mlfunc @missing_rbrace() {
return %a
mlfunc @d {return} // expected-error {{expected ',' or '}'}}
// -----
mlfunc @malformed_type(%a : intt) { // expected-error {{expected type}}
}
// -----

View File

@ -96,22 +96,32 @@ mlfunc @emptyMLF() {
return // CHECK: return
} // CHECK: }
// CHECK-LABEL: mlfunc @mlfunc_with_args(f16) {
mlfunc @mlfunc_with_args(%a : f16) {
return %a // CHECK: return
}
// CHECK-LABEL: cfgfunc @cfgfunc_with_ops() {
cfgfunc @cfgfunc_with_ops() {
bb0:
%t = "getTensor"() : () -> tensor<4x4x?xf32>
// CHECK: dim xxx, 2 : sometype
%a = "dim"(%t){index: 2} : (tensor<4x4x?xf32>) -> affineint
// CHECK: addf xx, yy : sometype
"addf"() : () -> ()
// CHECK: return
return
}
// CHECK-LABEL: mlfunc @loops() {
mlfunc @loops() {
for { // CHECK: for {
for { // CHECK: for {
// CHECK: for x = 1 to 100 step 2 {
for %i = 1 to 100 step 2 {
// CHECK: for x = 1 to 200 {
for %j = 1 to 200 {
} // CHECK: }
} // CHECK: }
return // CHECK: return
@ -119,14 +129,14 @@ mlfunc @loops() {
// CHECK-LABEL: mlfunc @ifstmt() {
mlfunc @ifstmt() {
for { // CHECK for {
if () { // CHECK if () {
} else if () { // CHECK } else if () {
} else { // CHECK } else {
} // CHECK }
} // CHECK }
return // CHECK return
} // CHECK }
for %i = 1 to 10 { // CHECK for x = 1 to 10 {
if () { // CHECK if () {
} else if () { // CHECK } else if () {
} else { // CHECK } else {
} // CHECK }
} // CHECK }
return // CHECK return
} // CHECK }
// CHECK-LABEL: cfgfunc @attributes() {
cfgfunc @attributes() {