forked from OSchip/llvm-project
Parse ML function arguments, return statement operands, and for statement loop header.
Loop bounds and presumed to be constants for now and are stored in ForStmt as affine constant expressions. ML function arguments, return statement operands and loop variable name are dropped for now. PiperOrigin-RevId: 205256208
This commit is contained in:
parent
72c24e3e71
commit
6ada91db02
|
@ -184,11 +184,10 @@ public:
|
|||
return op;
|
||||
}
|
||||
|
||||
ForStmt *createFor() {
|
||||
auto stmt = new ForStmt();
|
||||
block->getStatements().push_back(stmt);
|
||||
return stmt;
|
||||
}
|
||||
// Creates for statement. When step is not specified, it is set to 1.
|
||||
ForStmt *createFor(AffineConstantExpr *lowerBound,
|
||||
AffineConstantExpr *upperBound,
|
||||
AffineConstantExpr *step = nullptr);
|
||||
|
||||
IfStmt *createIf() {
|
||||
auto stmt = new IfStmt();
|
||||
|
|
|
@ -22,10 +22,11 @@
|
|||
#ifndef MLIR_IR_STATEMENTS_H
|
||||
#define MLIR_IR_STATEMENTS_H
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Statement.h"
|
||||
#include "mlir/IR/StmtBlock.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
@ -50,11 +51,20 @@ public:
|
|||
/// For statement represents an affine loop nest.
|
||||
class ForStmt : public Statement, public StmtBlock {
|
||||
public:
|
||||
explicit ForStmt() : Statement(Kind::For), StmtBlock(StmtBlockKind::For) {}
|
||||
// TODO: lower and upper bounds should be affine maps with
|
||||
// dimension and symbol use lists.
|
||||
explicit ForStmt(AffineConstantExpr *lowerBound,
|
||||
AffineConstantExpr *upperBound, AffineConstantExpr *step)
|
||||
: Statement(Kind::For), StmtBlock(StmtBlockKind::For),
|
||||
lowerBound(lowerBound), upperBound(upperBound), step(step) {}
|
||||
|
||||
//TODO: delete nested statements or assert that they are gone.
|
||||
~ForStmt() {}
|
||||
|
||||
// TODO: represent loop variable, bounds and step
|
||||
// TODO: represent induction variable
|
||||
AffineConstantExpr *getLowerBound() const { return lowerBound; }
|
||||
AffineConstantExpr *getUpperBound() const { return upperBound; }
|
||||
AffineConstantExpr *getStep() const { return step; }
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Statement *stmt) {
|
||||
|
@ -64,9 +74,14 @@ public:
|
|||
static bool classof(const StmtBlock *block) {
|
||||
return block->getStmtBlockKind() == StmtBlockKind::For;
|
||||
}
|
||||
|
||||
private:
|
||||
AffineConstantExpr *lowerBound;
|
||||
AffineConstantExpr *upperBound;
|
||||
AffineConstantExpr *step;
|
||||
};
|
||||
|
||||
/// If clause represents statements contained within then or else clause
|
||||
/// An if clause represents statements contained within a then or an else clause
|
||||
/// of an if statement.
|
||||
class IfClause : public StmtBlock {
|
||||
public:
|
||||
|
|
|
@ -594,7 +594,12 @@ void MLFunctionState::print(const Statement *stmt) {
|
|||
void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
|
||||
|
||||
void MLFunctionState::print(const ForStmt *stmt) {
|
||||
os.indent(numSpaces) << "for {\n";
|
||||
os.indent(numSpaces) << "for x = " << *stmt->getLowerBound();
|
||||
os << " to " << *stmt->getUpperBound();
|
||||
if (stmt->getStep()->getValue() != 1)
|
||||
os << " step " << *stmt->getStep();
|
||||
|
||||
os << " {\n";
|
||||
print(static_cast<const StmtBlock *>(stmt));
|
||||
os.indent(numSpaces) << "}";
|
||||
}
|
||||
|
|
|
@ -143,3 +143,17 @@ AffineExpr *Builder::getFloorDivExpr(AffineExpr *lhs, AffineExpr *rhs) {
|
|||
AffineExpr *Builder::getCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs) {
|
||||
return AffineCeilDivExpr::get(lhs, rhs, context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Statements
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound,
|
||||
AffineConstantExpr *upperBound,
|
||||
AffineConstantExpr *step) {
|
||||
if (!step)
|
||||
step = getConstantExpr(1);
|
||||
auto stmt = new ForStmt(lowerBound, upperBound, step);
|
||||
block->getStatements().push_back(stmt);
|
||||
return stmt;
|
||||
}
|
||||
|
|
|
@ -1581,6 +1581,7 @@ private:
|
|||
MLFuncBuilder builder;
|
||||
|
||||
ParseResult parseForStmt();
|
||||
AffineConstantExpr *parseIntConstant();
|
||||
ParseResult parseIfStmt();
|
||||
ParseResult parseElseClause(IfClause *elseClause);
|
||||
ParseResult parseStatements(StmtBlock *block);
|
||||
|
@ -1598,10 +1599,11 @@ ParseResult MLFunctionParser::parseFunctionBody() {
|
|||
|
||||
if (!consumeIf(Token::kw_return))
|
||||
emitError("ML function must end with return statement");
|
||||
// TODO: parse return statement operands
|
||||
|
||||
if (!consumeIf(Token::r_brace))
|
||||
emitError("expected '}' in ML function");
|
||||
// TODO: store return operands in the IR.
|
||||
SmallVector<SSAUseInfo, 4> dummyUseInfo;
|
||||
if (parseOptionalSSAUseList(Token::r_brace, dummyUseInfo))
|
||||
return ParseFailure;
|
||||
|
||||
getModule()->functionList.push_back(function);
|
||||
|
||||
|
@ -1616,17 +1618,66 @@ ParseResult MLFunctionParser::parseFunctionBody() {
|
|||
ParseResult MLFunctionParser::parseForStmt() {
|
||||
consumeToken(Token::kw_for);
|
||||
|
||||
//TODO: parse loop header
|
||||
ForStmt *stmt = builder.createFor();
|
||||
// Parse induction variable
|
||||
if (getToken().isNot(Token::percent_identifier))
|
||||
return emitError("expected SSA identifier for the loop variable");
|
||||
|
||||
// If parsing of the for statement body fails
|
||||
// MLIR contains for statement with successfully parsed nested statements
|
||||
// TODO: create SSA value definition from name
|
||||
StringRef name = getTokenSpelling().drop_front();
|
||||
(void)name;
|
||||
|
||||
consumeToken(Token::percent_identifier);
|
||||
|
||||
if (!consumeIf(Token::equal))
|
||||
return emitError("expected =");
|
||||
|
||||
// Parse loop bounds
|
||||
AffineConstantExpr *lowerBound = parseIntConstant();
|
||||
if (!lowerBound)
|
||||
return ParseFailure;
|
||||
|
||||
if (!consumeIf(Token::kw_to))
|
||||
return emitError("expected 'to' between bounds");
|
||||
|
||||
AffineConstantExpr *upperBound = parseIntConstant();
|
||||
if (!upperBound)
|
||||
return ParseFailure;
|
||||
|
||||
// Parse step
|
||||
AffineConstantExpr *step = nullptr;
|
||||
if (consumeIf(Token::kw_step)) {
|
||||
step = parseIntConstant();
|
||||
if (!step)
|
||||
return ParseFailure;
|
||||
}
|
||||
|
||||
// Create for statement.
|
||||
ForStmt *stmt = builder.createFor(lowerBound, upperBound, step);
|
||||
|
||||
// If parsing of the for statement body fails,
|
||||
// MLIR contains for statement with those nested statements that have been
|
||||
// successfully parsed.
|
||||
if (parseStmtBlock(static_cast<StmtBlock *>(stmt)))
|
||||
return ParseFailure;
|
||||
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
||||
// This method is temporary workaround to parse simple loop bounds and
|
||||
// step.
|
||||
// TODO: remove this method once it's no longer used.
|
||||
AffineConstantExpr *MLFunctionParser::parseIntConstant() {
|
||||
if (getToken().isNot(Token::integer))
|
||||
return (emitError("expected non-negative integer for now"), nullptr);
|
||||
|
||||
auto val = getToken().getUInt64IntegerValue();
|
||||
if (!val.hasValue() || (int64_t)val.getValue() < 0) {
|
||||
return (emitError("constant too large for affineint"), nullptr);
|
||||
}
|
||||
consumeToken(Token::integer);
|
||||
return builder.getConstantExpr((int64_t)val.getValue());
|
||||
}
|
||||
|
||||
/// If statement.
|
||||
///
|
||||
/// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}`
|
||||
|
@ -1642,13 +1693,14 @@ ParseResult MLFunctionParser::parseIfStmt() {
|
|||
//TODO: parse condition
|
||||
|
||||
if (!consumeIf(Token::r_paren))
|
||||
return emitError("expected )");
|
||||
return emitError("expected ')'");
|
||||
|
||||
IfStmt *ifStmt = builder.createIf();
|
||||
IfClause *thenClause = ifStmt->getThenClause();
|
||||
|
||||
// If parsing of the then or optional else clause fails MLIR contains
|
||||
// if statement with successfully parsed nested statements.
|
||||
// When parsing of an if statement body fails, the IR contains
|
||||
// the if statement with the portion of the body that has been
|
||||
// successfully parsed.
|
||||
if (parseStmtBlock(thenClause))
|
||||
return ParseFailure;
|
||||
|
||||
|
@ -1735,7 +1787,10 @@ private:
|
|||
ParseResult parseAffineMapDef();
|
||||
|
||||
// Functions.
|
||||
ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type);
|
||||
ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
|
||||
SmallVectorImpl<StringRef> &argNames);
|
||||
ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type,
|
||||
SmallVectorImpl<StringRef> *argNames);
|
||||
ParseResult parseExtFunc();
|
||||
ParseResult parseCFGFunc();
|
||||
ParseResult parseMLFunc();
|
||||
|
@ -1769,14 +1824,50 @@ ParseResult ModuleParser::parseAffineMapDef() {
|
|||
return ParseSuccess;
|
||||
}
|
||||
|
||||
/// Parse a (possibly empty) list of MLFunction arguments with types.
|
||||
///
|
||||
/// ml-argument ::= ssa-id `:` type
|
||||
/// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
|
||||
///
|
||||
ParseResult
|
||||
ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
|
||||
SmallVectorImpl<StringRef> &argNames) {
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
// Parse argument name
|
||||
if (getToken().isNot(Token::percent_identifier))
|
||||
return emitError("expected SSA identifier");
|
||||
|
||||
StringRef name = getTokenSpelling().drop_front();
|
||||
consumeToken(Token::percent_identifier);
|
||||
argNames.push_back(name);
|
||||
|
||||
if (!consumeIf(Token::colon))
|
||||
return emitError("expected ':'");
|
||||
|
||||
// Parse argument type
|
||||
auto elt = parseType();
|
||||
if (!elt)
|
||||
return ParseFailure;
|
||||
argTypes.push_back(elt);
|
||||
|
||||
return ParseSuccess;
|
||||
};
|
||||
|
||||
if (!consumeIf(Token::l_paren))
|
||||
llvm_unreachable("expected '('");
|
||||
|
||||
return parseCommaSeparatedList(Token::r_paren, parseElt);
|
||||
}
|
||||
|
||||
/// Parse a function signature, starting with a name and including the parameter
|
||||
/// list.
|
||||
///
|
||||
/// argument-list ::= type (`,` type)* | /*empty*/
|
||||
/// argument-list ::= type (`,` type)* | /*empty*/ | ml-argument-list
|
||||
/// function-signature ::= function-id `(` argument-list `)` (`->` type-list)?
|
||||
///
|
||||
ParseResult ModuleParser::parseFunctionSignature(StringRef &name,
|
||||
FunctionType *&type) {
|
||||
ParseResult
|
||||
ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
|
||||
SmallVectorImpl<StringRef> *argNames) {
|
||||
if (getToken().isNot(Token::at_identifier))
|
||||
return emitError("expected a function identifier like '@foo'");
|
||||
|
||||
|
@ -1786,8 +1877,15 @@ ParseResult ModuleParser::parseFunctionSignature(StringRef &name,
|
|||
if (getToken().isNot(Token::l_paren))
|
||||
return emitError("expected '(' in function signature");
|
||||
|
||||
SmallVector<Type *, 4> arguments;
|
||||
if (parseTypeList(arguments))
|
||||
SmallVector<Type *, 4> argTypes;
|
||||
ParseResult parseResult;
|
||||
|
||||
if (argNames)
|
||||
parseResult = parseMLArgumentList(argTypes, *argNames);
|
||||
else
|
||||
parseResult = parseTypeList(argTypes);
|
||||
|
||||
if (parseResult)
|
||||
return ParseFailure;
|
||||
|
||||
// Parse the return type if present.
|
||||
|
@ -1796,7 +1894,7 @@ ParseResult ModuleParser::parseFunctionSignature(StringRef &name,
|
|||
if (parseTypeList(results))
|
||||
return ParseFailure;
|
||||
}
|
||||
type = builder.getFunctionType(arguments, results);
|
||||
type = builder.getFunctionType(argTypes, results);
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
||||
|
@ -1809,7 +1907,7 @@ ParseResult ModuleParser::parseExtFunc() {
|
|||
|
||||
StringRef name;
|
||||
FunctionType *type = nullptr;
|
||||
if (parseFunctionSignature(name, type))
|
||||
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
|
||||
return ParseFailure;
|
||||
|
||||
// Okay, the external function definition was parsed correctly.
|
||||
|
@ -1826,7 +1924,7 @@ ParseResult ModuleParser::parseCFGFunc() {
|
|||
|
||||
StringRef name;
|
||||
FunctionType *type = nullptr;
|
||||
if (parseFunctionSignature(name, type))
|
||||
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
|
||||
return ParseFailure;
|
||||
|
||||
// Okay, the CFG function signature was parsed correctly, create the function.
|
||||
|
@ -1844,10 +1942,11 @@ ParseResult ModuleParser::parseMLFunc() {
|
|||
|
||||
StringRef name;
|
||||
FunctionType *type = nullptr;
|
||||
|
||||
SmallVector<StringRef, 4> argNames;
|
||||
// FIXME: Parse ML function signature (args + types)
|
||||
// by passing pointer to SmallVector<identifier> into parseFunctionSignature
|
||||
if (parseFunctionSignature(name, type))
|
||||
|
||||
if (parseFunctionSignature(name, type, &argNames))
|
||||
return ParseFailure;
|
||||
|
||||
// Okay, the ML function signature was parsed correctly, create the function.
|
||||
|
|
|
@ -76,7 +76,7 @@ public:
|
|||
/// return None.
|
||||
Optional<unsigned> getUnsignedIntegerValue() const;
|
||||
|
||||
/// For an integer token, return its value as an int64_t. If it doesn't fit,
|
||||
/// For an integer token, return its value as an uint64_t. If it doesn't fit,
|
||||
/// return None.
|
||||
Optional<uint64_t> getUInt64IntegerValue() const;
|
||||
|
||||
|
|
|
@ -104,7 +104,9 @@ TOK_KEYWORD(mlfunc)
|
|||
TOK_KEYWORD(mod)
|
||||
TOK_KEYWORD(return)
|
||||
TOK_KEYWORD(size)
|
||||
TOK_KEYWORD(step)
|
||||
TOK_KEYWORD(tensor)
|
||||
TOK_KEYWORD(to)
|
||||
TOK_KEYWORD(true)
|
||||
TOK_KEYWORD(vector)
|
||||
|
||||
|
|
|
@ -130,12 +130,24 @@ extfunc @illegaltype(i0) // expected-error {{invalid integer width}}
|
|||
|
||||
// -----
|
||||
|
||||
mlfunc @malformed_for() {
|
||||
for %i = 1 too 10 { // expected-error {{expected 'to' between bounds}}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
mlfunc @incomplete_for() {
|
||||
for
|
||||
for %i = 1 to 10 step 2
|
||||
} // expected-error {{expected '{' before statement list}}
|
||||
|
||||
// -----
|
||||
|
||||
mlfunc @nonconstant_step(%1 : i32) {
|
||||
for %2 = 1 to 5 step %1 { // expected-error {{expected non-negative integer for now}}
|
||||
|
||||
// -----
|
||||
|
||||
mlfunc @non_statement() {
|
||||
asd // expected-error {{expected operation name in quotes}}
|
||||
}
|
||||
|
@ -160,7 +172,6 @@ bb40:
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @redef() {
|
||||
|
@ -168,4 +179,16 @@ bb42:
|
|||
%x = "dim"(){index: 0} : ()->i32
|
||||
%x = "dim"(){index: 0} : ()->i32 // expected-error {{redefinition of SSA value %x}}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
mlfunc @missing_rbrace() {
|
||||
return %a
|
||||
mlfunc @d {return} // expected-error {{expected ',' or '}'}}
|
||||
|
||||
// -----
|
||||
|
||||
mlfunc @malformed_type(%a : intt) { // expected-error {{expected type}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
|
|
@ -96,22 +96,32 @@ mlfunc @emptyMLF() {
|
|||
return // CHECK: return
|
||||
} // CHECK: }
|
||||
|
||||
// CHECK-LABEL: mlfunc @mlfunc_with_args(f16) {
|
||||
mlfunc @mlfunc_with_args(%a : f16) {
|
||||
return %a // CHECK: return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: cfgfunc @cfgfunc_with_ops() {
|
||||
cfgfunc @cfgfunc_with_ops() {
|
||||
bb0:
|
||||
%t = "getTensor"() : () -> tensor<4x4x?xf32>
|
||||
|
||||
// CHECK: dim xxx, 2 : sometype
|
||||
%a = "dim"(%t){index: 2} : (tensor<4x4x?xf32>) -> affineint
|
||||
|
||||
// CHECK: addf xx, yy : sometype
|
||||
"addf"() : () -> ()
|
||||
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @loops() {
|
||||
mlfunc @loops() {
|
||||
for { // CHECK: for {
|
||||
for { // CHECK: for {
|
||||
// CHECK: for x = 1 to 100 step 2 {
|
||||
for %i = 1 to 100 step 2 {
|
||||
// CHECK: for x = 1 to 200 {
|
||||
for %j = 1 to 200 {
|
||||
} // CHECK: }
|
||||
} // CHECK: }
|
||||
return // CHECK: return
|
||||
|
@ -119,14 +129,14 @@ mlfunc @loops() {
|
|||
|
||||
// CHECK-LABEL: mlfunc @ifstmt() {
|
||||
mlfunc @ifstmt() {
|
||||
for { // CHECK for {
|
||||
if () { // CHECK if () {
|
||||
} else if () { // CHECK } else if () {
|
||||
} else { // CHECK } else {
|
||||
} // CHECK }
|
||||
} // CHECK }
|
||||
return // CHECK return
|
||||
} // CHECK }
|
||||
for %i = 1 to 10 { // CHECK for x = 1 to 10 {
|
||||
if () { // CHECK if () {
|
||||
} else if () { // CHECK } else if () {
|
||||
} else { // CHECK } else {
|
||||
} // CHECK }
|
||||
} // CHECK }
|
||||
return // CHECK return
|
||||
} // CHECK }
|
||||
|
||||
// CHECK-LABEL: cfgfunc @attributes() {
|
||||
cfgfunc @attributes() {
|
||||
|
|
Loading…
Reference in New Issue