forked from OSchip/llvm-project
NFC: Uniformize parser naming scheme in Toy tutorial to camelCase and tidy a bit of the implementation.
PiperOrigin-RevId: 278982817
This commit is contained in:
parent
f6188b5b07
commit
22cfff7043
|
@ -6,4 +6,7 @@ add_toy_chapter(toyc-ch1
|
|||
toyc.cpp
|
||||
parser/AST.cpp
|
||||
)
|
||||
include_directories(include/)
|
||||
include_directories(include/)
|
||||
target_link_libraries(toyc-ch1
|
||||
PRIVATE
|
||||
MLIRSupport)
|
||||
|
|
|
@ -54,7 +54,6 @@ public:
|
|||
|
||||
ExprAST(ExprASTKind kind, Location location)
|
||||
: kind(kind), location(location) {}
|
||||
|
||||
virtual ~ExprAST() = default;
|
||||
|
||||
ExprASTKind getKind() const { return kind; }
|
||||
|
@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST {
|
|||
double Val;
|
||||
|
||||
public:
|
||||
NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
|
||||
NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
|
||||
|
||||
double getValue() { return Val; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
|
||||
};
|
||||
|
||||
/// Expression class for a literal value.
|
||||
|
@ -93,10 +92,11 @@ public:
|
|||
: ExprAST(Expr_Literal, loc), values(std::move(values)),
|
||||
dims(std::move(dims)) {}
|
||||
|
||||
std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
|
||||
std::vector<int64_t> &getDims() { return dims; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
|
||||
llvm::ArrayRef<int64_t> getDims() { return dims; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
|
||||
};
|
||||
|
||||
/// Expression class for referencing a variable, like "a".
|
||||
|
@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST {
|
|||
std::string name;
|
||||
|
||||
public:
|
||||
VariableExprAST(Location loc, const std::string &name)
|
||||
VariableExprAST(Location loc, llvm::StringRef name)
|
||||
: ExprAST(Expr_Var, loc), name(name) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
|
||||
};
|
||||
|
||||
/// Expression class for defining a variable.
|
||||
|
@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST {
|
|||
std::unique_ptr<ExprAST> initVal;
|
||||
|
||||
public:
|
||||
VarDeclExprAST(Location loc, const std::string &name, VarType type,
|
||||
VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
|
||||
std::unique_ptr<ExprAST> initVal)
|
||||
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
|
||||
initVal(std::move(initVal)) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
ExprAST *getInitVal() { return initVal.get(); }
|
||||
VarType &getType() { return type; }
|
||||
const VarType &getType() { return type; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
|
||||
};
|
||||
|
||||
/// Expression class for a return operator.
|
||||
|
@ -144,61 +144,61 @@ public:
|
|||
llvm::Optional<ExprAST *> getExpr() {
|
||||
if (expr.hasValue())
|
||||
return expr->get();
|
||||
return llvm::NoneType();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
|
||||
};
|
||||
|
||||
/// Expression class for a binary operator.
|
||||
class BinaryExprAST : public ExprAST {
|
||||
char Op;
|
||||
std::unique_ptr<ExprAST> LHS, RHS;
|
||||
char op;
|
||||
std::unique_ptr<ExprAST> lhs, rhs;
|
||||
|
||||
public:
|
||||
char getOp() { return Op; }
|
||||
ExprAST *getLHS() { return LHS.get(); }
|
||||
ExprAST *getRHS() { return RHS.get(); }
|
||||
char getOp() { return op; }
|
||||
ExprAST *getLHS() { return lhs.get(); }
|
||||
ExprAST *getRHS() { return rhs.get(); }
|
||||
|
||||
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
|
||||
std::unique_ptr<ExprAST> RHS)
|
||||
: ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
|
||||
RHS(std::move(RHS)) {}
|
||||
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
|
||||
std::unique_ptr<ExprAST> rhs)
|
||||
: ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
|
||||
rhs(std::move(rhs)) {}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
|
||||
};
|
||||
|
||||
/// Expression class for function calls.
|
||||
class CallExprAST : public ExprAST {
|
||||
std::string Callee;
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
std::string callee;
|
||||
std::vector<std::unique_ptr<ExprAST>> args;
|
||||
|
||||
public:
|
||||
CallExprAST(Location loc, const std::string &Callee,
|
||||
std::vector<std::unique_ptr<ExprAST>> Args)
|
||||
: ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
|
||||
CallExprAST(Location loc, const std::string &callee,
|
||||
std::vector<std::unique_ptr<ExprAST>> args)
|
||||
: ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
|
||||
|
||||
llvm::StringRef getCallee() { return Callee; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
|
||||
llvm::StringRef getCallee() { return callee; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
|
||||
};
|
||||
|
||||
/// Expression class for builtin print calls.
|
||||
class PrintExprAST : public ExprAST {
|
||||
std::unique_ptr<ExprAST> Arg;
|
||||
std::unique_ptr<ExprAST> arg;
|
||||
|
||||
public:
|
||||
PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
|
||||
: ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
|
||||
PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
|
||||
: ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
|
||||
|
||||
ExprAST *getArg() { return Arg.get(); }
|
||||
ExprAST *getArg() { return arg.get(); }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
|
||||
};
|
||||
|
||||
/// This class represents the "prototype" for a function, which captures its
|
||||
|
@ -215,23 +215,21 @@ public:
|
|||
: location(location), name(name), args(std::move(args)) {}
|
||||
|
||||
const Location &loc() { return location; }
|
||||
const std::string &getName() const { return name; }
|
||||
const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
|
||||
return args;
|
||||
}
|
||||
llvm::StringRef getName() const { return name; }
|
||||
llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
|
||||
};
|
||||
|
||||
/// This class represents a function definition itself.
|
||||
class FunctionAST {
|
||||
std::unique_ptr<PrototypeAST> Proto;
|
||||
std::unique_ptr<ExprASTList> Body;
|
||||
std::unique_ptr<PrototypeAST> proto;
|
||||
std::unique_ptr<ExprASTList> body;
|
||||
|
||||
public:
|
||||
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
|
||||
std::unique_ptr<ExprASTList> Body)
|
||||
: Proto(std::move(Proto)), Body(std::move(Body)) {}
|
||||
PrototypeAST *getProto() { return Proto.get(); }
|
||||
ExprASTList *getBody() { return Body.get(); }
|
||||
FunctionAST(std::unique_ptr<PrototypeAST> proto,
|
||||
std::unique_ptr<ExprASTList> body)
|
||||
: proto(std::move(proto)), body(std::move(body)) {}
|
||||
PrototypeAST *getProto() { return proto.get(); }
|
||||
ExprASTList *getBody() { return body.get(); }
|
||||
};
|
||||
|
||||
/// This class represents a list of functions to be processed together
|
||||
|
|
|
@ -89,13 +89,13 @@ public:
|
|||
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
|
||||
llvm::StringRef getId() {
|
||||
assert(curTok == tok_identifier);
|
||||
return IdentifierStr;
|
||||
return identifierStr;
|
||||
}
|
||||
|
||||
/// Return the current number (prereq: getCurToken() == tok_number)
|
||||
double getValue() {
|
||||
assert(curTok == tok_number);
|
||||
return NumVal;
|
||||
return numVal;
|
||||
}
|
||||
|
||||
/// Return the location for the beginning of the current token.
|
||||
|
@ -135,56 +135,58 @@ private:
|
|||
/// Return the next token from standard input.
|
||||
Token getTok() {
|
||||
// Skip any whitespace.
|
||||
while (isspace(LastChar))
|
||||
LastChar = Token(getNextChar());
|
||||
while (isspace(lastChar))
|
||||
lastChar = Token(getNextChar());
|
||||
|
||||
// Save the current location before reading the token characters.
|
||||
lastLocation.line = curLineNum;
|
||||
lastLocation.col = curCol;
|
||||
|
||||
if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||
IdentifierStr = (char)LastChar;
|
||||
while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
|
||||
IdentifierStr += (char)LastChar;
|
||||
// Identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||
if (isalpha(lastChar)) {
|
||||
identifierStr = (char)lastChar;
|
||||
while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
|
||||
identifierStr += (char)lastChar;
|
||||
|
||||
if (IdentifierStr == "return")
|
||||
if (identifierStr == "return")
|
||||
return tok_return;
|
||||
if (IdentifierStr == "def")
|
||||
if (identifierStr == "def")
|
||||
return tok_def;
|
||||
if (IdentifierStr == "var")
|
||||
if (identifierStr == "var")
|
||||
return tok_var;
|
||||
return tok_identifier;
|
||||
}
|
||||
|
||||
if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
|
||||
std::string NumStr;
|
||||
// Number: [0-9.]+
|
||||
if (isdigit(lastChar) || lastChar == '.') {
|
||||
std::string numStr;
|
||||
do {
|
||||
NumStr += LastChar;
|
||||
LastChar = Token(getNextChar());
|
||||
} while (isdigit(LastChar) || LastChar == '.');
|
||||
numStr += lastChar;
|
||||
lastChar = Token(getNextChar());
|
||||
} while (isdigit(lastChar) || lastChar == '.');
|
||||
|
||||
NumVal = strtod(NumStr.c_str(), nullptr);
|
||||
numVal = strtod(numStr.c_str(), nullptr);
|
||||
return tok_number;
|
||||
}
|
||||
|
||||
if (LastChar == '#') {
|
||||
if (lastChar == '#') {
|
||||
// Comment until end of line.
|
||||
do
|
||||
LastChar = Token(getNextChar());
|
||||
while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
|
||||
do {
|
||||
lastChar = Token(getNextChar());
|
||||
} while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
|
||||
|
||||
if (LastChar != EOF)
|
||||
if (lastChar != EOF)
|
||||
return getTok();
|
||||
}
|
||||
|
||||
// Check for end of file. Don't eat the EOF.
|
||||
if (LastChar == EOF)
|
||||
if (lastChar == EOF)
|
||||
return tok_eof;
|
||||
|
||||
// Otherwise, just return the character as its ascii value.
|
||||
Token ThisChar = Token(LastChar);
|
||||
LastChar = Token(getNextChar());
|
||||
return ThisChar;
|
||||
Token thisChar = Token(lastChar);
|
||||
lastChar = Token(getNextChar());
|
||||
return thisChar;
|
||||
}
|
||||
|
||||
/// The last token read from the input.
|
||||
|
@ -194,15 +196,15 @@ private:
|
|||
Location lastLocation;
|
||||
|
||||
/// If the current Token is an identifier, this string contains the value.
|
||||
std::string IdentifierStr;
|
||||
std::string identifierStr;
|
||||
|
||||
/// If the current Token is a number, this contains the value.
|
||||
double NumVal = 0;
|
||||
double numVal = 0;
|
||||
|
||||
/// The last value returned by getNextChar(). We need to keep it around as we
|
||||
/// always need to read ahead one character to decide when to end a token and
|
||||
/// we can't put it back in the stream after reading from it.
|
||||
Token LastChar = Token(' ');
|
||||
Token lastChar = Token(' ');
|
||||
|
||||
/// Keep track of the current line number in the input stream
|
||||
int curLineNum = 0;
|
||||
|
|
|
@ -48,13 +48,13 @@ public:
|
|||
Parser(Lexer &lexer) : lexer(lexer) {}
|
||||
|
||||
/// Parse a full Module. A module is a list of function definitions.
|
||||
std::unique_ptr<ModuleAST> ParseModule() {
|
||||
std::unique_ptr<ModuleAST> parseModule() {
|
||||
lexer.getNextToken(); // prime the lexer
|
||||
|
||||
// Parse functions one at a time and accumulate in this vector.
|
||||
std::vector<FunctionAST> functions;
|
||||
while (auto F = ParseDefinition()) {
|
||||
functions.push_back(std::move(*F));
|
||||
while (auto f = parseDefinition()) {
|
||||
functions.push_back(std::move(*f));
|
||||
if (lexer.getCurToken() == tok_eof)
|
||||
break;
|
||||
}
|
||||
|
@ -70,14 +70,14 @@ private:
|
|||
|
||||
/// Parse a return statement.
|
||||
/// return :== return ; | return expr ;
|
||||
std::unique_ptr<ReturnExprAST> ParseReturn() {
|
||||
std::unique_ptr<ReturnExprAST> parseReturn() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_return);
|
||||
|
||||
// return takes an optional argument
|
||||
llvm::Optional<std::unique_ptr<ExprAST>> expr;
|
||||
if (lexer.getCurToken() != ';') {
|
||||
expr = ParseExpression();
|
||||
expr = parseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -86,18 +86,18 @@ private:
|
|||
|
||||
/// Parse a literal number.
|
||||
/// numberexpr ::= number
|
||||
std::unique_ptr<ExprAST> ParseNumberExpr() {
|
||||
std::unique_ptr<ExprAST> parseNumberExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
auto Result =
|
||||
auto result =
|
||||
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
|
||||
lexer.consume(tok_number);
|
||||
return std::move(Result);
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
/// Parse a literal array expression.
|
||||
/// tensorLiteral ::= [ literalList ] | number
|
||||
/// literalList ::= tensorLiteral | tensorLiteral, literalList
|
||||
std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
|
||||
std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(Token('['));
|
||||
|
||||
|
@ -108,13 +108,13 @@ private:
|
|||
do {
|
||||
// We can have either another nested array or a number literal.
|
||||
if (lexer.getCurToken() == '[') {
|
||||
values.push_back(ParseTensorLiteralExpr());
|
||||
values.push_back(parseTensorLiteralExpr());
|
||||
if (!values.back())
|
||||
return nullptr; // parse error in the nested array.
|
||||
} else {
|
||||
if (lexer.getCurToken() != tok_number)
|
||||
return parseError<ExprAST>("<num> or [", "in literal expression");
|
||||
values.push_back(ParseNumberExpr());
|
||||
values.push_back(parseNumberExpr());
|
||||
}
|
||||
|
||||
// End of this list on ']'
|
||||
|
@ -130,8 +130,10 @@ private:
|
|||
if (values.empty())
|
||||
return parseError<ExprAST>("<something>", "to fill literal expression");
|
||||
lexer.getNextToken(); // eat ]
|
||||
|
||||
/// Fill in the dimensions now. First the current nesting level:
|
||||
dims.push_back(values.size());
|
||||
|
||||
/// If there is any nested array, process all of them and ensure that
|
||||
/// dimensions are uniform.
|
||||
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
|
||||
|
@ -143,7 +145,7 @@ private:
|
|||
"inside literal expression");
|
||||
|
||||
// Append the nested dimensions to the current level
|
||||
auto &firstDims = firstLiteral->getDims();
|
||||
auto firstDims = firstLiteral->getDims();
|
||||
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
|
||||
|
||||
// Sanity check that shape is uniform across all elements of the list.
|
||||
|
@ -162,22 +164,22 @@ private:
|
|||
}
|
||||
|
||||
/// parenexpr ::= '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseParenExpr() {
|
||||
std::unique_ptr<ExprAST> parseParenExpr() {
|
||||
lexer.getNextToken(); // eat (.
|
||||
auto V = ParseExpression();
|
||||
if (!V)
|
||||
auto v = parseExpression();
|
||||
if (!v)
|
||||
return nullptr;
|
||||
|
||||
if (lexer.getCurToken() != ')')
|
||||
return parseError<ExprAST>(")", "to close expression with parentheses");
|
||||
lexer.consume(Token(')'));
|
||||
return V;
|
||||
return v;
|
||||
}
|
||||
|
||||
/// identifierexpr
|
||||
/// ::= identifier
|
||||
/// ::= identifier '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseIdentifierExpr() {
|
||||
std::unique_ptr<ExprAST> parseIdentifierExpr() {
|
||||
std::string name = lexer.getId();
|
||||
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
@ -188,11 +190,11 @@ private:
|
|||
|
||||
// This is a function call.
|
||||
lexer.consume(Token('('));
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
std::vector<std::unique_ptr<ExprAST>> args;
|
||||
if (lexer.getCurToken() != ')') {
|
||||
while (true) {
|
||||
if (auto Arg = ParseExpression())
|
||||
Args.push_back(std::move(Arg));
|
||||
if (auto arg = parseExpression())
|
||||
args.push_back(std::move(arg));
|
||||
else
|
||||
return nullptr;
|
||||
|
||||
|
@ -208,14 +210,14 @@ private:
|
|||
|
||||
// It can be a builtin call to print
|
||||
if (name == "print") {
|
||||
if (Args.size() != 1)
|
||||
if (args.size() != 1)
|
||||
return parseError<ExprAST>("<single arg>", "as argument to print()");
|
||||
|
||||
return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0]));
|
||||
return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
|
||||
}
|
||||
|
||||
// Call to a user-defined function
|
||||
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args));
|
||||
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
|
||||
}
|
||||
|
||||
/// primary
|
||||
|
@ -223,20 +225,20 @@ private:
|
|||
/// ::= numberexpr
|
||||
/// ::= parenexpr
|
||||
/// ::= tensorliteral
|
||||
std::unique_ptr<ExprAST> ParsePrimary() {
|
||||
std::unique_ptr<ExprAST> parsePrimary() {
|
||||
switch (lexer.getCurToken()) {
|
||||
default:
|
||||
llvm::errs() << "unknown token '" << lexer.getCurToken()
|
||||
<< "' when expecting an expression\n";
|
||||
return nullptr;
|
||||
case tok_identifier:
|
||||
return ParseIdentifierExpr();
|
||||
return parseIdentifierExpr();
|
||||
case tok_number:
|
||||
return ParseNumberExpr();
|
||||
return parseNumberExpr();
|
||||
case '(':
|
||||
return ParseParenExpr();
|
||||
return parseParenExpr();
|
||||
case '[':
|
||||
return ParseTensorLiteralExpr();
|
||||
return parseTensorLiteralExpr();
|
||||
case ';':
|
||||
return nullptr;
|
||||
case '}':
|
||||
|
@ -248,54 +250,54 @@ private:
|
|||
/// argument indicates the precedence of the current binary operator.
|
||||
///
|
||||
/// binoprhs ::= ('+' primary)*
|
||||
std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
|
||||
std::unique_ptr<ExprAST> LHS) {
|
||||
std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
|
||||
std::unique_ptr<ExprAST> lhs) {
|
||||
// If this is a binop, find its precedence.
|
||||
while (true) {
|
||||
int TokPrec = GetTokPrecedence();
|
||||
int tokPrec = getTokPrecedence();
|
||||
|
||||
// If this is a binop that binds at least as tightly as the current binop,
|
||||
// consume it, otherwise we are done.
|
||||
if (TokPrec < ExprPrec)
|
||||
return LHS;
|
||||
if (tokPrec < exprPrec)
|
||||
return lhs;
|
||||
|
||||
// Okay, we know this is a binop.
|
||||
int BinOp = lexer.getCurToken();
|
||||
lexer.consume(Token(BinOp));
|
||||
int binOp = lexer.getCurToken();
|
||||
lexer.consume(Token(binOp));
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
||||
// Parse the primary expression after the binary operator.
|
||||
auto RHS = ParsePrimary();
|
||||
if (!RHS)
|
||||
auto rhs = parsePrimary();
|
||||
if (!rhs)
|
||||
return parseError<ExprAST>("expression", "to complete binary operator");
|
||||
|
||||
// If BinOp binds less tightly with RHS than the operator after RHS, let
|
||||
// the pending operator take RHS as its LHS.
|
||||
int NextPrec = GetTokPrecedence();
|
||||
if (TokPrec < NextPrec) {
|
||||
RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
|
||||
if (!RHS)
|
||||
// If BinOp binds less tightly with rhs than the operator after rhs, let
|
||||
// the pending operator take rhs as its lhs.
|
||||
int nextPrec = getTokPrecedence();
|
||||
if (tokPrec < nextPrec) {
|
||||
rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
|
||||
if (!rhs)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Merge LHS/RHS.
|
||||
LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp,
|
||||
std::move(LHS), std::move(RHS));
|
||||
// Merge lhs/RHS.
|
||||
lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
|
||||
std::move(lhs), std::move(rhs));
|
||||
}
|
||||
}
|
||||
|
||||
/// expression::= primary binoprhs
|
||||
std::unique_ptr<ExprAST> ParseExpression() {
|
||||
auto LHS = ParsePrimary();
|
||||
if (!LHS)
|
||||
/// expression::= primary binop rhs
|
||||
std::unique_ptr<ExprAST> parseExpression() {
|
||||
auto lhs = parsePrimary();
|
||||
if (!lhs)
|
||||
return nullptr;
|
||||
|
||||
return ParseBinOpRHS(0, std::move(LHS));
|
||||
return parseBinOpRHS(0, std::move(lhs));
|
||||
}
|
||||
|
||||
/// type ::= < shape_list >
|
||||
/// shape_list ::= num | num , shape_list
|
||||
std::unique_ptr<VarType> ParseType() {
|
||||
std::unique_ptr<VarType> parseType() {
|
||||
if (lexer.getCurToken() != '<')
|
||||
return parseError<VarType>("<", "to begin type");
|
||||
lexer.getNextToken(); // eat <
|
||||
|
@ -319,7 +321,7 @@ private:
|
|||
/// and identifier and an optional type (shape specification) before the
|
||||
/// initializer.
|
||||
/// decl ::= var identifier [ type ] = expr
|
||||
std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
|
||||
std::unique_ptr<VarDeclExprAST> parseDeclaration() {
|
||||
if (lexer.getCurToken() != tok_var)
|
||||
return parseError<VarDeclExprAST>("var", "to begin declaration");
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
@ -333,7 +335,7 @@ private:
|
|||
|
||||
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
|
||||
if (lexer.getCurToken() == '<') {
|
||||
type = ParseType();
|
||||
type = parseType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -341,7 +343,7 @@ private:
|
|||
if (!type)
|
||||
type = std::make_unique<VarType>();
|
||||
lexer.consume(Token('='));
|
||||
auto expr = ParseExpression();
|
||||
auto expr = parseExpression();
|
||||
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
|
||||
std::move(*type), std::move(expr));
|
||||
}
|
||||
|
@ -352,7 +354,7 @@ private:
|
|||
/// block ::= { expression_list }
|
||||
/// expression_list ::= block_expr ; expression_list
|
||||
/// block_expr ::= decl | "return" | expr
|
||||
std::unique_ptr<ExprASTList> ParseBlock() {
|
||||
std::unique_ptr<ExprASTList> parseBlock() {
|
||||
if (lexer.getCurToken() != '{')
|
||||
return parseError<ExprASTList>("{", "to begin block");
|
||||
lexer.consume(Token('{'));
|
||||
|
@ -366,19 +368,19 @@ private:
|
|||
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
|
||||
if (lexer.getCurToken() == tok_var) {
|
||||
// Variable declaration
|
||||
auto varDecl = ParseDeclaration();
|
||||
auto varDecl = parseDeclaration();
|
||||
if (!varDecl)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(varDecl));
|
||||
} else if (lexer.getCurToken() == tok_return) {
|
||||
// Return statement
|
||||
auto ret = ParseReturn();
|
||||
auto ret = parseReturn();
|
||||
if (!ret)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(ret));
|
||||
} else {
|
||||
// General expression
|
||||
auto expr = ParseExpression();
|
||||
auto expr = parseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(expr));
|
||||
|
@ -401,13 +403,13 @@ private:
|
|||
|
||||
/// prototype ::= def id '(' decl_list ')'
|
||||
/// decl_list ::= identifier | identifier, decl_list
|
||||
std::unique_ptr<PrototypeAST> ParsePrototype() {
|
||||
std::unique_ptr<PrototypeAST> parsePrototype() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_def);
|
||||
if (lexer.getCurToken() != tok_identifier)
|
||||
return parseError<PrototypeAST>("function name", "in prototype");
|
||||
|
||||
std::string FnName = lexer.getId();
|
||||
std::string fnName = lexer.getId();
|
||||
lexer.consume(tok_identifier);
|
||||
|
||||
if (lexer.getCurToken() != '(')
|
||||
|
@ -435,7 +437,7 @@ private:
|
|||
|
||||
// success.
|
||||
lexer.consume(Token(')'));
|
||||
return std::make_unique<PrototypeAST>(std::move(loc), FnName,
|
||||
return std::make_unique<PrototypeAST>(std::move(loc), fnName,
|
||||
std::move(args));
|
||||
}
|
||||
|
||||
|
@ -443,18 +445,18 @@ private:
|
|||
/// `def` keyword, followed by a block containing a list of expressions.
|
||||
///
|
||||
/// definition ::= prototype block
|
||||
std::unique_ptr<FunctionAST> ParseDefinition() {
|
||||
auto Proto = ParsePrototype();
|
||||
if (!Proto)
|
||||
std::unique_ptr<FunctionAST> parseDefinition() {
|
||||
auto proto = parsePrototype();
|
||||
if (!proto)
|
||||
return nullptr;
|
||||
|
||||
if (auto block = ParseBlock())
|
||||
return std::make_unique<FunctionAST>(std::move(Proto), std::move(block));
|
||||
if (auto block = parseBlock())
|
||||
return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Get the precedence of the pending binary operator token.
|
||||
int GetTokPrecedence() {
|
||||
int getTokPrecedence() {
|
||||
if (!isascii(lexer.getCurToken()))
|
||||
return -1;
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "toy/AST.h"
|
||||
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
|
@ -40,22 +41,22 @@ struct Indent {
|
|||
/// the way. The only data member is the current indentation level.
|
||||
class ASTDumper {
|
||||
public:
|
||||
void dump(ModuleAST *Node);
|
||||
void dump(ModuleAST *node);
|
||||
|
||||
private:
|
||||
void dump(VarType &type);
|
||||
void dump(const VarType &type);
|
||||
void dump(VarDeclExprAST *varDecl);
|
||||
void dump(ExprAST *expr);
|
||||
void dump(ExprASTList *exprList);
|
||||
void dump(NumberExprAST *num);
|
||||
void dump(LiteralExprAST *Node);
|
||||
void dump(VariableExprAST *Node);
|
||||
void dump(ReturnExprAST *Node);
|
||||
void dump(BinaryExprAST *Node);
|
||||
void dump(CallExprAST *Node);
|
||||
void dump(PrintExprAST *Node);
|
||||
void dump(PrototypeAST *Node);
|
||||
void dump(FunctionAST *Node);
|
||||
void dump(LiteralExprAST *node);
|
||||
void dump(VariableExprAST *node);
|
||||
void dump(ReturnExprAST *node);
|
||||
void dump(BinaryExprAST *node);
|
||||
void dump(CallExprAST *node);
|
||||
void dump(PrintExprAST *node);
|
||||
void dump(PrototypeAST *node);
|
||||
void dump(FunctionAST *node);
|
||||
|
||||
// Actually print spaces matching the current indentation level
|
||||
void indent() {
|
||||
|
@ -68,8 +69,8 @@ private:
|
|||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *Node) {
|
||||
const auto &loc = Node->loc();
|
||||
template <typename T> static std::string loc(T *node) {
|
||||
const auto &loc = node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
.str();
|
||||
|
@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) {
|
|||
/// [ [ 1, 2 ], [ 3, 4 ] ]
|
||||
/// We print out such array with the dimensions spelled out at every level:
|
||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||
void printLitHelper(ExprAST *lit_or_num) {
|
||||
void printLitHelper(ExprAST *litOrNum) {
|
||||
// Inside a literal expression we can have either a number or another literal
|
||||
if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
|
||||
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
||||
llvm::errs() << num->getValue();
|
||||
return;
|
||||
}
|
||||
auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
|
||||
auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
|
||||
|
||||
// Print the dimension for this literal first
|
||||
llvm::errs() << "<";
|
||||
{
|
||||
const char *sep = "";
|
||||
for (auto dim : literal->getDims()) {
|
||||
llvm::errs() << sep << dim;
|
||||
sep = ", ";
|
||||
}
|
||||
}
|
||||
mlir::interleaveComma(literal->getDims(), llvm::errs());
|
||||
llvm::errs() << ">";
|
||||
|
||||
// Now print the content, recursing on every element of the list
|
||||
llvm::errs() << "[ ";
|
||||
const char *sep = "";
|
||||
for (auto &elt : literal->getValues()) {
|
||||
llvm::errs() << sep;
|
||||
printLitHelper(elt.get());
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(literal->getValues(), llvm::errs(),
|
||||
[&](auto &elt) { printLitHelper(elt.get()); });
|
||||
llvm::errs() << "]";
|
||||
}
|
||||
|
||||
/// Print a literal, see the recursive helper above for the implementation.
|
||||
void ASTDumper::dump(LiteralExprAST *Node) {
|
||||
void ASTDumper::dump(LiteralExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Literal: ";
|
||||
printLitHelper(Node);
|
||||
llvm::errs() << " " << loc(Node) << "\n";
|
||||
printLitHelper(node);
|
||||
llvm::errs() << " " << loc(node) << "\n";
|
||||
}
|
||||
|
||||
/// Print a variable reference (just a name).
|
||||
void ASTDumper::dump(VariableExprAST *Node) {
|
||||
void ASTDumper::dump(VariableExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
|
||||
llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
|
||||
}
|
||||
|
||||
/// Return statement print the return and its (optional) argument.
|
||||
void ASTDumper::dump(ReturnExprAST *Node) {
|
||||
void ASTDumper::dump(ReturnExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Return\n";
|
||||
if (Node->getExpr().hasValue())
|
||||
return dump(*Node->getExpr());
|
||||
if (node->getExpr().hasValue())
|
||||
return dump(*node->getExpr());
|
||||
{
|
||||
INDENT();
|
||||
llvm::errs() << "(void)\n";
|
||||
|
@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) {
|
|||
}
|
||||
|
||||
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
|
||||
void ASTDumper::dump(BinaryExprAST *Node) {
|
||||
void ASTDumper::dump(BinaryExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
|
||||
dump(Node->getLHS());
|
||||
dump(Node->getRHS());
|
||||
llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
|
||||
dump(node->getLHS());
|
||||
dump(node->getRHS());
|
||||
}
|
||||
|
||||
/// Print a call expression, first the callee name and the list of args by
|
||||
/// recursing into each individual argument.
|
||||
void ASTDumper::dump(CallExprAST *Node) {
|
||||
void ASTDumper::dump(CallExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
|
||||
for (auto &arg : Node->getArgs())
|
||||
llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
|
||||
for (auto &arg : node->getArgs())
|
||||
dump(arg.get());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a builtin print call, first the builtin name and then the argument.
|
||||
void ASTDumper::dump(PrintExprAST *Node) {
|
||||
void ASTDumper::dump(PrintExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Print [ " << loc(Node) << "\n";
|
||||
dump(Node->getArg());
|
||||
llvm::errs() << "Print [ " << loc(node) << "\n";
|
||||
dump(node->getArg());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print type: only the shape is printed in between '<' and '>'
|
||||
void ASTDumper::dump(VarType &type) {
|
||||
void ASTDumper::dump(const VarType &type) {
|
||||
llvm::errs() << "<";
|
||||
const char *sep = "";
|
||||
for (auto shape : type.shape) {
|
||||
llvm::errs() << sep << shape;
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(type.shape, llvm::errs());
|
||||
llvm::errs() << ">";
|
||||
}
|
||||
|
||||
/// Print a function prototype, first the function name, and then the list of
|
||||
/// parameters names.
|
||||
void ASTDumper::dump(PrototypeAST *Node) {
|
||||
void ASTDumper::dump(PrototypeAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
|
||||
llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
|
||||
indent();
|
||||
llvm::errs() << "Params: [";
|
||||
const char *sep = "";
|
||||
for (auto &arg : Node->getArgs()) {
|
||||
llvm::errs() << sep << arg->getName();
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(node->getArgs(), llvm::errs(),
|
||||
[](auto &arg) { llvm::errs() << arg->getName(); });
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a function, first the prototype and then the body.
|
||||
void ASTDumper::dump(FunctionAST *Node) {
|
||||
void ASTDumper::dump(FunctionAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Function \n";
|
||||
dump(Node->getProto());
|
||||
dump(Node->getBody());
|
||||
dump(node->getProto());
|
||||
dump(node->getBody());
|
||||
}
|
||||
|
||||
/// Print a module, actually loop over the functions and print them in sequence.
|
||||
void ASTDumper::dump(ModuleAST *Node) {
|
||||
void ASTDumper::dump(ModuleAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Module:\n";
|
||||
for (auto &F : *Node)
|
||||
dump(&F);
|
||||
for (auto &f : *node)
|
||||
dump(&f);
|
||||
}
|
||||
|
||||
namespace toy {
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
using namespace toy;
|
||||
namespace cl = llvm::cl;
|
||||
|
||||
static cl::opt<std::string> InputFilename(cl::Positional,
|
||||
static cl::opt<std::string> inputFilename(cl::Positional,
|
||||
cl::desc("<input toy file>"),
|
||||
cl::init("-"),
|
||||
cl::value_desc("filename"));
|
||||
|
@ -44,22 +44,22 @@ static cl::opt<enum Action>
|
|||
|
||||
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
|
||||
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
llvm::MemoryBuffer::getFileOrSTDIN(filename);
|
||||
if (std::error_code EC = FileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
||||
if (std::error_code ec = fileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
auto buffer = FileOrErr.get()->getBuffer();
|
||||
auto buffer = fileOrErr.get()->getBuffer();
|
||||
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
|
||||
Parser parser(lexer);
|
||||
return parser.ParseModule();
|
||||
return parser.parseModule();
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
|
||||
|
||||
auto moduleAST = parseInputFile(InputFilename);
|
||||
auto moduleAST = parseInputFile(inputFilename);
|
||||
if (!moduleAST)
|
||||
return 1;
|
||||
|
||||
|
|
|
@ -54,7 +54,6 @@ public:
|
|||
|
||||
ExprAST(ExprASTKind kind, Location location)
|
||||
: kind(kind), location(location) {}
|
||||
|
||||
virtual ~ExprAST() = default;
|
||||
|
||||
ExprASTKind getKind() const { return kind; }
|
||||
|
@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST {
|
|||
double Val;
|
||||
|
||||
public:
|
||||
NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
|
||||
NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
|
||||
|
||||
double getValue() { return Val; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
|
||||
};
|
||||
|
||||
/// Expression class for a literal value.
|
||||
|
@ -93,10 +92,11 @@ public:
|
|||
: ExprAST(Expr_Literal, loc), values(std::move(values)),
|
||||
dims(std::move(dims)) {}
|
||||
|
||||
std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
|
||||
std::vector<int64_t> &getDims() { return dims; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
|
||||
llvm::ArrayRef<int64_t> getDims() { return dims; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
|
||||
};
|
||||
|
||||
/// Expression class for referencing a variable, like "a".
|
||||
|
@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST {
|
|||
std::string name;
|
||||
|
||||
public:
|
||||
VariableExprAST(Location loc, const std::string &name)
|
||||
VariableExprAST(Location loc, llvm::StringRef name)
|
||||
: ExprAST(Expr_Var, loc), name(name) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
|
||||
};
|
||||
|
||||
/// Expression class for defining a variable.
|
||||
|
@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST {
|
|||
std::unique_ptr<ExprAST> initVal;
|
||||
|
||||
public:
|
||||
VarDeclExprAST(Location loc, const std::string &name, VarType type,
|
||||
VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
|
||||
std::unique_ptr<ExprAST> initVal)
|
||||
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
|
||||
initVal(std::move(initVal)) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
ExprAST *getInitVal() { return initVal.get(); }
|
||||
VarType &getType() { return type; }
|
||||
const VarType &getType() { return type; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
|
||||
};
|
||||
|
||||
/// Expression class for a return operator.
|
||||
|
@ -144,61 +144,61 @@ public:
|
|||
llvm::Optional<ExprAST *> getExpr() {
|
||||
if (expr.hasValue())
|
||||
return expr->get();
|
||||
return llvm::NoneType();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
|
||||
};
|
||||
|
||||
/// Expression class for a binary operator.
|
||||
class BinaryExprAST : public ExprAST {
|
||||
char Op;
|
||||
std::unique_ptr<ExprAST> LHS, RHS;
|
||||
char op;
|
||||
std::unique_ptr<ExprAST> lhs, rhs;
|
||||
|
||||
public:
|
||||
char getOp() { return Op; }
|
||||
ExprAST *getLHS() { return LHS.get(); }
|
||||
ExprAST *getRHS() { return RHS.get(); }
|
||||
char getOp() { return op; }
|
||||
ExprAST *getLHS() { return lhs.get(); }
|
||||
ExprAST *getRHS() { return rhs.get(); }
|
||||
|
||||
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
|
||||
std::unique_ptr<ExprAST> RHS)
|
||||
: ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
|
||||
RHS(std::move(RHS)) {}
|
||||
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
|
||||
std::unique_ptr<ExprAST> rhs)
|
||||
: ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
|
||||
rhs(std::move(rhs)) {}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
|
||||
};
|
||||
|
||||
/// Expression class for function calls.
|
||||
class CallExprAST : public ExprAST {
|
||||
std::string Callee;
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
std::string callee;
|
||||
std::vector<std::unique_ptr<ExprAST>> args;
|
||||
|
||||
public:
|
||||
CallExprAST(Location loc, const std::string &Callee,
|
||||
std::vector<std::unique_ptr<ExprAST>> Args)
|
||||
: ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
|
||||
CallExprAST(Location loc, const std::string &callee,
|
||||
std::vector<std::unique_ptr<ExprAST>> args)
|
||||
: ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
|
||||
|
||||
llvm::StringRef getCallee() { return Callee; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
|
||||
llvm::StringRef getCallee() { return callee; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
|
||||
};
|
||||
|
||||
/// Expression class for builtin print calls.
|
||||
class PrintExprAST : public ExprAST {
|
||||
std::unique_ptr<ExprAST> Arg;
|
||||
std::unique_ptr<ExprAST> arg;
|
||||
|
||||
public:
|
||||
PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
|
||||
: ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
|
||||
PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
|
||||
: ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
|
||||
|
||||
ExprAST *getArg() { return Arg.get(); }
|
||||
ExprAST *getArg() { return arg.get(); }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
|
||||
};
|
||||
|
||||
/// This class represents the "prototype" for a function, which captures its
|
||||
|
@ -215,23 +215,21 @@ public:
|
|||
: location(location), name(name), args(std::move(args)) {}
|
||||
|
||||
const Location &loc() { return location; }
|
||||
const std::string &getName() const { return name; }
|
||||
const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
|
||||
return args;
|
||||
}
|
||||
llvm::StringRef getName() const { return name; }
|
||||
llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
|
||||
};
|
||||
|
||||
/// This class represents a function definition itself.
|
||||
class FunctionAST {
|
||||
std::unique_ptr<PrototypeAST> Proto;
|
||||
std::unique_ptr<ExprASTList> Body;
|
||||
std::unique_ptr<PrototypeAST> proto;
|
||||
std::unique_ptr<ExprASTList> body;
|
||||
|
||||
public:
|
||||
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
|
||||
std::unique_ptr<ExprASTList> Body)
|
||||
: Proto(std::move(Proto)), Body(std::move(Body)) {}
|
||||
PrototypeAST *getProto() { return Proto.get(); }
|
||||
ExprASTList *getBody() { return Body.get(); }
|
||||
FunctionAST(std::unique_ptr<PrototypeAST> proto,
|
||||
std::unique_ptr<ExprASTList> body)
|
||||
: proto(std::move(proto)), body(std::move(body)) {}
|
||||
PrototypeAST *getProto() { return proto.get(); }
|
||||
ExprASTList *getBody() { return body.get(); }
|
||||
};
|
||||
|
||||
/// This class represents a list of functions to be processed together
|
||||
|
|
|
@ -89,13 +89,13 @@ public:
|
|||
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
|
||||
llvm::StringRef getId() {
|
||||
assert(curTok == tok_identifier);
|
||||
return IdentifierStr;
|
||||
return identifierStr;
|
||||
}
|
||||
|
||||
/// Return the current number (prereq: getCurToken() == tok_number)
|
||||
double getValue() {
|
||||
assert(curTok == tok_number);
|
||||
return NumVal;
|
||||
return numVal;
|
||||
}
|
||||
|
||||
/// Return the location for the beginning of the current token.
|
||||
|
@ -135,56 +135,58 @@ private:
|
|||
/// Return the next token from standard input.
|
||||
Token getTok() {
|
||||
// Skip any whitespace.
|
||||
while (isspace(LastChar))
|
||||
LastChar = Token(getNextChar());
|
||||
while (isspace(lastChar))
|
||||
lastChar = Token(getNextChar());
|
||||
|
||||
// Save the current location before reading the token characters.
|
||||
lastLocation.line = curLineNum;
|
||||
lastLocation.col = curCol;
|
||||
|
||||
if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||
IdentifierStr = (char)LastChar;
|
||||
while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
|
||||
IdentifierStr += (char)LastChar;
|
||||
// Identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||
if (isalpha(lastChar)) {
|
||||
identifierStr = (char)lastChar;
|
||||
while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
|
||||
identifierStr += (char)lastChar;
|
||||
|
||||
if (IdentifierStr == "return")
|
||||
if (identifierStr == "return")
|
||||
return tok_return;
|
||||
if (IdentifierStr == "def")
|
||||
if (identifierStr == "def")
|
||||
return tok_def;
|
||||
if (IdentifierStr == "var")
|
||||
if (identifierStr == "var")
|
||||
return tok_var;
|
||||
return tok_identifier;
|
||||
}
|
||||
|
||||
if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
|
||||
std::string NumStr;
|
||||
// Number: [0-9.]+
|
||||
if (isdigit(lastChar) || lastChar == '.') {
|
||||
std::string numStr;
|
||||
do {
|
||||
NumStr += LastChar;
|
||||
LastChar = Token(getNextChar());
|
||||
} while (isdigit(LastChar) || LastChar == '.');
|
||||
numStr += lastChar;
|
||||
lastChar = Token(getNextChar());
|
||||
} while (isdigit(lastChar) || lastChar == '.');
|
||||
|
||||
NumVal = strtod(NumStr.c_str(), nullptr);
|
||||
numVal = strtod(numStr.c_str(), nullptr);
|
||||
return tok_number;
|
||||
}
|
||||
|
||||
if (LastChar == '#') {
|
||||
if (lastChar == '#') {
|
||||
// Comment until end of line.
|
||||
do
|
||||
LastChar = Token(getNextChar());
|
||||
while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
|
||||
do {
|
||||
lastChar = Token(getNextChar());
|
||||
} while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
|
||||
|
||||
if (LastChar != EOF)
|
||||
if (lastChar != EOF)
|
||||
return getTok();
|
||||
}
|
||||
|
||||
// Check for end of file. Don't eat the EOF.
|
||||
if (LastChar == EOF)
|
||||
if (lastChar == EOF)
|
||||
return tok_eof;
|
||||
|
||||
// Otherwise, just return the character as its ascii value.
|
||||
Token ThisChar = Token(LastChar);
|
||||
LastChar = Token(getNextChar());
|
||||
return ThisChar;
|
||||
Token thisChar = Token(lastChar);
|
||||
lastChar = Token(getNextChar());
|
||||
return thisChar;
|
||||
}
|
||||
|
||||
/// The last token read from the input.
|
||||
|
@ -194,15 +196,15 @@ private:
|
|||
Location lastLocation;
|
||||
|
||||
/// If the current Token is an identifier, this string contains the value.
|
||||
std::string IdentifierStr;
|
||||
std::string identifierStr;
|
||||
|
||||
/// If the current Token is a number, this contains the value.
|
||||
double NumVal = 0;
|
||||
double numVal = 0;
|
||||
|
||||
/// The last value returned by getNextChar(). We need to keep it around as we
|
||||
/// always need to read ahead one character to decide when to end a token and
|
||||
/// we can't put it back in the stream after reading from it.
|
||||
Token LastChar = Token(' ');
|
||||
Token lastChar = Token(' ');
|
||||
|
||||
/// Keep track of the current line number in the input stream
|
||||
int curLineNum = 0;
|
||||
|
|
|
@ -48,13 +48,13 @@ public:
|
|||
Parser(Lexer &lexer) : lexer(lexer) {}
|
||||
|
||||
/// Parse a full Module. A module is a list of function definitions.
|
||||
std::unique_ptr<ModuleAST> ParseModule() {
|
||||
std::unique_ptr<ModuleAST> parseModule() {
|
||||
lexer.getNextToken(); // prime the lexer
|
||||
|
||||
// Parse functions one at a time and accumulate in this vector.
|
||||
std::vector<FunctionAST> functions;
|
||||
while (auto F = ParseDefinition()) {
|
||||
functions.push_back(std::move(*F));
|
||||
while (auto f = parseDefinition()) {
|
||||
functions.push_back(std::move(*f));
|
||||
if (lexer.getCurToken() == tok_eof)
|
||||
break;
|
||||
}
|
||||
|
@ -70,14 +70,14 @@ private:
|
|||
|
||||
/// Parse a return statement.
|
||||
/// return :== return ; | return expr ;
|
||||
std::unique_ptr<ReturnExprAST> ParseReturn() {
|
||||
std::unique_ptr<ReturnExprAST> parseReturn() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_return);
|
||||
|
||||
// return takes an optional argument
|
||||
llvm::Optional<std::unique_ptr<ExprAST>> expr;
|
||||
if (lexer.getCurToken() != ';') {
|
||||
expr = ParseExpression();
|
||||
expr = parseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -86,18 +86,18 @@ private:
|
|||
|
||||
/// Parse a literal number.
|
||||
/// numberexpr ::= number
|
||||
std::unique_ptr<ExprAST> ParseNumberExpr() {
|
||||
std::unique_ptr<ExprAST> parseNumberExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
auto Result =
|
||||
auto result =
|
||||
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
|
||||
lexer.consume(tok_number);
|
||||
return std::move(Result);
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
/// Parse a literal array expression.
|
||||
/// tensorLiteral ::= [ literalList ] | number
|
||||
/// literalList ::= tensorLiteral | tensorLiteral, literalList
|
||||
std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
|
||||
std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(Token('['));
|
||||
|
||||
|
@ -108,13 +108,13 @@ private:
|
|||
do {
|
||||
// We can have either another nested array or a number literal.
|
||||
if (lexer.getCurToken() == '[') {
|
||||
values.push_back(ParseTensorLiteralExpr());
|
||||
values.push_back(parseTensorLiteralExpr());
|
||||
if (!values.back())
|
||||
return nullptr; // parse error in the nested array.
|
||||
} else {
|
||||
if (lexer.getCurToken() != tok_number)
|
||||
return parseError<ExprAST>("<num> or [", "in literal expression");
|
||||
values.push_back(ParseNumberExpr());
|
||||
values.push_back(parseNumberExpr());
|
||||
}
|
||||
|
||||
// End of this list on ']'
|
||||
|
@ -130,8 +130,10 @@ private:
|
|||
if (values.empty())
|
||||
return parseError<ExprAST>("<something>", "to fill literal expression");
|
||||
lexer.getNextToken(); // eat ]
|
||||
|
||||
/// Fill in the dimensions now. First the current nesting level:
|
||||
dims.push_back(values.size());
|
||||
|
||||
/// If there is any nested array, process all of them and ensure that
|
||||
/// dimensions are uniform.
|
||||
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
|
||||
|
@ -143,7 +145,7 @@ private:
|
|||
"inside literal expression");
|
||||
|
||||
// Append the nested dimensions to the current level
|
||||
auto &firstDims = firstLiteral->getDims();
|
||||
auto firstDims = firstLiteral->getDims();
|
||||
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
|
||||
|
||||
// Sanity check that shape is uniform across all elements of the list.
|
||||
|
@ -162,22 +164,22 @@ private:
|
|||
}
|
||||
|
||||
/// parenexpr ::= '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseParenExpr() {
|
||||
std::unique_ptr<ExprAST> parseParenExpr() {
|
||||
lexer.getNextToken(); // eat (.
|
||||
auto V = ParseExpression();
|
||||
if (!V)
|
||||
auto v = parseExpression();
|
||||
if (!v)
|
||||
return nullptr;
|
||||
|
||||
if (lexer.getCurToken() != ')')
|
||||
return parseError<ExprAST>(")", "to close expression with parentheses");
|
||||
lexer.consume(Token(')'));
|
||||
return V;
|
||||
return v;
|
||||
}
|
||||
|
||||
/// identifierexpr
|
||||
/// ::= identifier
|
||||
/// ::= identifier '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseIdentifierExpr() {
|
||||
std::unique_ptr<ExprAST> parseIdentifierExpr() {
|
||||
std::string name = lexer.getId();
|
||||
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
@ -188,11 +190,11 @@ private:
|
|||
|
||||
// This is a function call.
|
||||
lexer.consume(Token('('));
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
std::vector<std::unique_ptr<ExprAST>> args;
|
||||
if (lexer.getCurToken() != ')') {
|
||||
while (true) {
|
||||
if (auto Arg = ParseExpression())
|
||||
Args.push_back(std::move(Arg));
|
||||
if (auto arg = parseExpression())
|
||||
args.push_back(std::move(arg));
|
||||
else
|
||||
return nullptr;
|
||||
|
||||
|
@ -208,14 +210,14 @@ private:
|
|||
|
||||
// It can be a builtin call to print
|
||||
if (name == "print") {
|
||||
if (Args.size() != 1)
|
||||
if (args.size() != 1)
|
||||
return parseError<ExprAST>("<single arg>", "as argument to print()");
|
||||
|
||||
return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0]));
|
||||
return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
|
||||
}
|
||||
|
||||
// Call to a user-defined function
|
||||
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args));
|
||||
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
|
||||
}
|
||||
|
||||
/// primary
|
||||
|
@ -223,20 +225,20 @@ private:
|
|||
/// ::= numberexpr
|
||||
/// ::= parenexpr
|
||||
/// ::= tensorliteral
|
||||
std::unique_ptr<ExprAST> ParsePrimary() {
|
||||
std::unique_ptr<ExprAST> parsePrimary() {
|
||||
switch (lexer.getCurToken()) {
|
||||
default:
|
||||
llvm::errs() << "unknown token '" << lexer.getCurToken()
|
||||
<< "' when expecting an expression\n";
|
||||
return nullptr;
|
||||
case tok_identifier:
|
||||
return ParseIdentifierExpr();
|
||||
return parseIdentifierExpr();
|
||||
case tok_number:
|
||||
return ParseNumberExpr();
|
||||
return parseNumberExpr();
|
||||
case '(':
|
||||
return ParseParenExpr();
|
||||
return parseParenExpr();
|
||||
case '[':
|
||||
return ParseTensorLiteralExpr();
|
||||
return parseTensorLiteralExpr();
|
||||
case ';':
|
||||
return nullptr;
|
||||
case '}':
|
||||
|
@ -248,54 +250,54 @@ private:
|
|||
/// argument indicates the precedence of the current binary operator.
|
||||
///
|
||||
/// binoprhs ::= ('+' primary)*
|
||||
std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
|
||||
std::unique_ptr<ExprAST> LHS) {
|
||||
std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
|
||||
std::unique_ptr<ExprAST> lhs) {
|
||||
// If this is a binop, find its precedence.
|
||||
while (true) {
|
||||
int TokPrec = GetTokPrecedence();
|
||||
int tokPrec = getTokPrecedence();
|
||||
|
||||
// If this is a binop that binds at least as tightly as the current binop,
|
||||
// consume it, otherwise we are done.
|
||||
if (TokPrec < ExprPrec)
|
||||
return LHS;
|
||||
if (tokPrec < exprPrec)
|
||||
return lhs;
|
||||
|
||||
// Okay, we know this is a binop.
|
||||
int BinOp = lexer.getCurToken();
|
||||
lexer.consume(Token(BinOp));
|
||||
int binOp = lexer.getCurToken();
|
||||
lexer.consume(Token(binOp));
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
||||
// Parse the primary expression after the binary operator.
|
||||
auto RHS = ParsePrimary();
|
||||
if (!RHS)
|
||||
auto rhs = parsePrimary();
|
||||
if (!rhs)
|
||||
return parseError<ExprAST>("expression", "to complete binary operator");
|
||||
|
||||
// If BinOp binds less tightly with RHS than the operator after RHS, let
|
||||
// the pending operator take RHS as its LHS.
|
||||
int NextPrec = GetTokPrecedence();
|
||||
if (TokPrec < NextPrec) {
|
||||
RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
|
||||
if (!RHS)
|
||||
// If BinOp binds less tightly with rhs than the operator after rhs, let
|
||||
// the pending operator take rhs as its lhs.
|
||||
int nextPrec = getTokPrecedence();
|
||||
if (tokPrec < nextPrec) {
|
||||
rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
|
||||
if (!rhs)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Merge LHS/RHS.
|
||||
LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp,
|
||||
std::move(LHS), std::move(RHS));
|
||||
// Merge lhs/RHS.
|
||||
lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
|
||||
std::move(lhs), std::move(rhs));
|
||||
}
|
||||
}
|
||||
|
||||
/// expression::= primary binoprhs
|
||||
std::unique_ptr<ExprAST> ParseExpression() {
|
||||
auto LHS = ParsePrimary();
|
||||
if (!LHS)
|
||||
/// expression::= primary binop rhs
|
||||
std::unique_ptr<ExprAST> parseExpression() {
|
||||
auto lhs = parsePrimary();
|
||||
if (!lhs)
|
||||
return nullptr;
|
||||
|
||||
return ParseBinOpRHS(0, std::move(LHS));
|
||||
return parseBinOpRHS(0, std::move(lhs));
|
||||
}
|
||||
|
||||
/// type ::= < shape_list >
|
||||
/// shape_list ::= num | num , shape_list
|
||||
std::unique_ptr<VarType> ParseType() {
|
||||
std::unique_ptr<VarType> parseType() {
|
||||
if (lexer.getCurToken() != '<')
|
||||
return parseError<VarType>("<", "to begin type");
|
||||
lexer.getNextToken(); // eat <
|
||||
|
@ -319,7 +321,7 @@ private:
|
|||
/// and identifier and an optional type (shape specification) before the
|
||||
/// initializer.
|
||||
/// decl ::= var identifier [ type ] = expr
|
||||
std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
|
||||
std::unique_ptr<VarDeclExprAST> parseDeclaration() {
|
||||
if (lexer.getCurToken() != tok_var)
|
||||
return parseError<VarDeclExprAST>("var", "to begin declaration");
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
@ -333,7 +335,7 @@ private:
|
|||
|
||||
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
|
||||
if (lexer.getCurToken() == '<') {
|
||||
type = ParseType();
|
||||
type = parseType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -341,7 +343,7 @@ private:
|
|||
if (!type)
|
||||
type = std::make_unique<VarType>();
|
||||
lexer.consume(Token('='));
|
||||
auto expr = ParseExpression();
|
||||
auto expr = parseExpression();
|
||||
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
|
||||
std::move(*type), std::move(expr));
|
||||
}
|
||||
|
@ -352,7 +354,7 @@ private:
|
|||
/// block ::= { expression_list }
|
||||
/// expression_list ::= block_expr ; expression_list
|
||||
/// block_expr ::= decl | "return" | expr
|
||||
std::unique_ptr<ExprASTList> ParseBlock() {
|
||||
std::unique_ptr<ExprASTList> parseBlock() {
|
||||
if (lexer.getCurToken() != '{')
|
||||
return parseError<ExprASTList>("{", "to begin block");
|
||||
lexer.consume(Token('{'));
|
||||
|
@ -366,19 +368,19 @@ private:
|
|||
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
|
||||
if (lexer.getCurToken() == tok_var) {
|
||||
// Variable declaration
|
||||
auto varDecl = ParseDeclaration();
|
||||
auto varDecl = parseDeclaration();
|
||||
if (!varDecl)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(varDecl));
|
||||
} else if (lexer.getCurToken() == tok_return) {
|
||||
// Return statement
|
||||
auto ret = ParseReturn();
|
||||
auto ret = parseReturn();
|
||||
if (!ret)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(ret));
|
||||
} else {
|
||||
// General expression
|
||||
auto expr = ParseExpression();
|
||||
auto expr = parseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(expr));
|
||||
|
@ -401,13 +403,13 @@ private:
|
|||
|
||||
/// prototype ::= def id '(' decl_list ')'
|
||||
/// decl_list ::= identifier | identifier, decl_list
|
||||
std::unique_ptr<PrototypeAST> ParsePrototype() {
|
||||
std::unique_ptr<PrototypeAST> parsePrototype() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_def);
|
||||
if (lexer.getCurToken() != tok_identifier)
|
||||
return parseError<PrototypeAST>("function name", "in prototype");
|
||||
|
||||
std::string FnName = lexer.getId();
|
||||
std::string fnName = lexer.getId();
|
||||
lexer.consume(tok_identifier);
|
||||
|
||||
if (lexer.getCurToken() != '(')
|
||||
|
@ -435,7 +437,7 @@ private:
|
|||
|
||||
// success.
|
||||
lexer.consume(Token(')'));
|
||||
return std::make_unique<PrototypeAST>(std::move(loc), FnName,
|
||||
return std::make_unique<PrototypeAST>(std::move(loc), fnName,
|
||||
std::move(args));
|
||||
}
|
||||
|
||||
|
@ -443,18 +445,18 @@ private:
|
|||
/// `def` keyword, followed by a block containing a list of expressions.
|
||||
///
|
||||
/// definition ::= prototype block
|
||||
std::unique_ptr<FunctionAST> ParseDefinition() {
|
||||
auto Proto = ParsePrototype();
|
||||
if (!Proto)
|
||||
std::unique_ptr<FunctionAST> parseDefinition() {
|
||||
auto proto = parsePrototype();
|
||||
if (!proto)
|
||||
return nullptr;
|
||||
|
||||
if (auto block = ParseBlock())
|
||||
return std::make_unique<FunctionAST>(std::move(Proto), std::move(block));
|
||||
if (auto block = parseBlock())
|
||||
return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Get the precedence of the pending binary operator token.
|
||||
int GetTokPrecedence() {
|
||||
int getTokPrecedence() {
|
||||
if (!isascii(lexer.getCurToken()))
|
||||
return -1;
|
||||
|
||||
|
|
|
@ -143,7 +143,7 @@ private:
|
|||
// In MLIR the entry block of the function is special: it must have the same
|
||||
// argument list as the function itself.
|
||||
auto &entryBlock = *function.addEntryBlock();
|
||||
auto &protoArgs = funcAST.getProto()->getArgs();
|
||||
auto protoArgs = funcAST.getProto()->getArgs();
|
||||
|
||||
// Declare all the function arguments in the symbol table.
|
||||
for (const auto &name_value :
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "toy/AST.h"
|
||||
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
|
@ -40,22 +41,22 @@ struct Indent {
|
|||
/// the way. The only data member is the current indentation level.
|
||||
class ASTDumper {
|
||||
public:
|
||||
void dump(ModuleAST *Node);
|
||||
void dump(ModuleAST *node);
|
||||
|
||||
private:
|
||||
void dump(VarType &type);
|
||||
void dump(const VarType &type);
|
||||
void dump(VarDeclExprAST *varDecl);
|
||||
void dump(ExprAST *expr);
|
||||
void dump(ExprASTList *exprList);
|
||||
void dump(NumberExprAST *num);
|
||||
void dump(LiteralExprAST *Node);
|
||||
void dump(VariableExprAST *Node);
|
||||
void dump(ReturnExprAST *Node);
|
||||
void dump(BinaryExprAST *Node);
|
||||
void dump(CallExprAST *Node);
|
||||
void dump(PrintExprAST *Node);
|
||||
void dump(PrototypeAST *Node);
|
||||
void dump(FunctionAST *Node);
|
||||
void dump(LiteralExprAST *node);
|
||||
void dump(VariableExprAST *node);
|
||||
void dump(ReturnExprAST *node);
|
||||
void dump(BinaryExprAST *node);
|
||||
void dump(CallExprAST *node);
|
||||
void dump(PrintExprAST *node);
|
||||
void dump(PrototypeAST *node);
|
||||
void dump(FunctionAST *node);
|
||||
|
||||
// Actually print spaces matching the current indentation level
|
||||
void indent() {
|
||||
|
@ -68,8 +69,8 @@ private:
|
|||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *Node) {
|
||||
const auto &loc = Node->loc();
|
||||
template <typename T> static std::string loc(T *node) {
|
||||
const auto &loc = node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
.str();
|
||||
|
@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) {
|
|||
/// [ [ 1, 2 ], [ 3, 4 ] ]
|
||||
/// We print out such array with the dimensions spelled out at every level:
|
||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||
void printLitHelper(ExprAST *lit_or_num) {
|
||||
void printLitHelper(ExprAST *litOrNum) {
|
||||
// Inside a literal expression we can have either a number or another literal
|
||||
if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
|
||||
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
||||
llvm::errs() << num->getValue();
|
||||
return;
|
||||
}
|
||||
auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
|
||||
auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
|
||||
|
||||
// Print the dimension for this literal first
|
||||
llvm::errs() << "<";
|
||||
{
|
||||
const char *sep = "";
|
||||
for (auto dim : literal->getDims()) {
|
||||
llvm::errs() << sep << dim;
|
||||
sep = ", ";
|
||||
}
|
||||
}
|
||||
mlir::interleaveComma(literal->getDims(), llvm::errs());
|
||||
llvm::errs() << ">";
|
||||
|
||||
// Now print the content, recursing on every element of the list
|
||||
llvm::errs() << "[ ";
|
||||
const char *sep = "";
|
||||
for (auto &elt : literal->getValues()) {
|
||||
llvm::errs() << sep;
|
||||
printLitHelper(elt.get());
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(literal->getValues(), llvm::errs(),
|
||||
[&](auto &elt) { printLitHelper(elt.get()); });
|
||||
llvm::errs() << "]";
|
||||
}
|
||||
|
||||
/// Print a literal, see the recursive helper above for the implementation.
|
||||
void ASTDumper::dump(LiteralExprAST *Node) {
|
||||
void ASTDumper::dump(LiteralExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Literal: ";
|
||||
printLitHelper(Node);
|
||||
llvm::errs() << " " << loc(Node) << "\n";
|
||||
printLitHelper(node);
|
||||
llvm::errs() << " " << loc(node) << "\n";
|
||||
}
|
||||
|
||||
/// Print a variable reference (just a name).
|
||||
void ASTDumper::dump(VariableExprAST *Node) {
|
||||
void ASTDumper::dump(VariableExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
|
||||
llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
|
||||
}
|
||||
|
||||
/// Return statement print the return and its (optional) argument.
|
||||
void ASTDumper::dump(ReturnExprAST *Node) {
|
||||
void ASTDumper::dump(ReturnExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Return\n";
|
||||
if (Node->getExpr().hasValue())
|
||||
return dump(*Node->getExpr());
|
||||
if (node->getExpr().hasValue())
|
||||
return dump(*node->getExpr());
|
||||
{
|
||||
INDENT();
|
||||
llvm::errs() << "(void)\n";
|
||||
|
@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) {
|
|||
}
|
||||
|
||||
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
|
||||
void ASTDumper::dump(BinaryExprAST *Node) {
|
||||
void ASTDumper::dump(BinaryExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
|
||||
dump(Node->getLHS());
|
||||
dump(Node->getRHS());
|
||||
llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
|
||||
dump(node->getLHS());
|
||||
dump(node->getRHS());
|
||||
}
|
||||
|
||||
/// Print a call expression, first the callee name and the list of args by
|
||||
/// recursing into each individual argument.
|
||||
void ASTDumper::dump(CallExprAST *Node) {
|
||||
void ASTDumper::dump(CallExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
|
||||
for (auto &arg : Node->getArgs())
|
||||
llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
|
||||
for (auto &arg : node->getArgs())
|
||||
dump(arg.get());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a builtin print call, first the builtin name and then the argument.
|
||||
void ASTDumper::dump(PrintExprAST *Node) {
|
||||
void ASTDumper::dump(PrintExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Print [ " << loc(Node) << "\n";
|
||||
dump(Node->getArg());
|
||||
llvm::errs() << "Print [ " << loc(node) << "\n";
|
||||
dump(node->getArg());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print type: only the shape is printed in between '<' and '>'
|
||||
void ASTDumper::dump(VarType &type) {
|
||||
void ASTDumper::dump(const VarType &type) {
|
||||
llvm::errs() << "<";
|
||||
const char *sep = "";
|
||||
for (auto shape : type.shape) {
|
||||
llvm::errs() << sep << shape;
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(type.shape, llvm::errs());
|
||||
llvm::errs() << ">";
|
||||
}
|
||||
|
||||
/// Print a function prototype, first the function name, and then the list of
|
||||
/// parameters names.
|
||||
void ASTDumper::dump(PrototypeAST *Node) {
|
||||
void ASTDumper::dump(PrototypeAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
|
||||
llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
|
||||
indent();
|
||||
llvm::errs() << "Params: [";
|
||||
const char *sep = "";
|
||||
for (auto &arg : Node->getArgs()) {
|
||||
llvm::errs() << sep << arg->getName();
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(node->getArgs(), llvm::errs(),
|
||||
[](auto &arg) { llvm::errs() << arg->getName(); });
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a function, first the prototype and then the body.
|
||||
void ASTDumper::dump(FunctionAST *Node) {
|
||||
void ASTDumper::dump(FunctionAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Function \n";
|
||||
dump(Node->getProto());
|
||||
dump(Node->getBody());
|
||||
dump(node->getProto());
|
||||
dump(node->getBody());
|
||||
}
|
||||
|
||||
/// Print a module, actually loop over the functions and print them in sequence.
|
||||
void ASTDumper::dump(ModuleAST *Node) {
|
||||
void ASTDumper::dump(ModuleAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Module:\n";
|
||||
for (auto &F : *Node)
|
||||
dump(&F);
|
||||
for (auto &f : *node)
|
||||
dump(&f);
|
||||
}
|
||||
|
||||
namespace toy {
|
||||
|
|
|
@ -63,16 +63,16 @@ static cl::opt<enum Action> emitAction(
|
|||
|
||||
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
|
||||
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
llvm::MemoryBuffer::getFileOrSTDIN(filename);
|
||||
if (std::error_code EC = FileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
||||
if (std::error_code ec = fileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
auto buffer = FileOrErr.get()->getBuffer();
|
||||
auto buffer = fileOrErr.get()->getBuffer();
|
||||
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
|
||||
Parser parser(lexer);
|
||||
return parser.ParseModule();
|
||||
return parser.parseModule();
|
||||
}
|
||||
|
||||
int dumpMLIR() {
|
||||
|
|
|
@ -54,7 +54,6 @@ public:
|
|||
|
||||
ExprAST(ExprASTKind kind, Location location)
|
||||
: kind(kind), location(location) {}
|
||||
|
||||
virtual ~ExprAST() = default;
|
||||
|
||||
ExprASTKind getKind() const { return kind; }
|
||||
|
@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST {
|
|||
double Val;
|
||||
|
||||
public:
|
||||
NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
|
||||
NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
|
||||
|
||||
double getValue() { return Val; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
|
||||
};
|
||||
|
||||
/// Expression class for a literal value.
|
||||
|
@ -93,10 +92,11 @@ public:
|
|||
: ExprAST(Expr_Literal, loc), values(std::move(values)),
|
||||
dims(std::move(dims)) {}
|
||||
|
||||
std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
|
||||
std::vector<int64_t> &getDims() { return dims; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
|
||||
llvm::ArrayRef<int64_t> getDims() { return dims; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
|
||||
};
|
||||
|
||||
/// Expression class for referencing a variable, like "a".
|
||||
|
@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST {
|
|||
std::string name;
|
||||
|
||||
public:
|
||||
VariableExprAST(Location loc, const std::string &name)
|
||||
VariableExprAST(Location loc, llvm::StringRef name)
|
||||
: ExprAST(Expr_Var, loc), name(name) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
|
||||
};
|
||||
|
||||
/// Expression class for defining a variable.
|
||||
|
@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST {
|
|||
std::unique_ptr<ExprAST> initVal;
|
||||
|
||||
public:
|
||||
VarDeclExprAST(Location loc, const std::string &name, VarType type,
|
||||
VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
|
||||
std::unique_ptr<ExprAST> initVal)
|
||||
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
|
||||
initVal(std::move(initVal)) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
ExprAST *getInitVal() { return initVal.get(); }
|
||||
VarType &getType() { return type; }
|
||||
const VarType &getType() { return type; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
|
||||
};
|
||||
|
||||
/// Expression class for a return operator.
|
||||
|
@ -144,61 +144,61 @@ public:
|
|||
llvm::Optional<ExprAST *> getExpr() {
|
||||
if (expr.hasValue())
|
||||
return expr->get();
|
||||
return llvm::NoneType();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
|
||||
};
|
||||
|
||||
/// Expression class for a binary operator.
|
||||
class BinaryExprAST : public ExprAST {
|
||||
char Op;
|
||||
std::unique_ptr<ExprAST> LHS, RHS;
|
||||
char op;
|
||||
std::unique_ptr<ExprAST> lhs, rhs;
|
||||
|
||||
public:
|
||||
char getOp() { return Op; }
|
||||
ExprAST *getLHS() { return LHS.get(); }
|
||||
ExprAST *getRHS() { return RHS.get(); }
|
||||
char getOp() { return op; }
|
||||
ExprAST *getLHS() { return lhs.get(); }
|
||||
ExprAST *getRHS() { return rhs.get(); }
|
||||
|
||||
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
|
||||
std::unique_ptr<ExprAST> RHS)
|
||||
: ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
|
||||
RHS(std::move(RHS)) {}
|
||||
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
|
||||
std::unique_ptr<ExprAST> rhs)
|
||||
: ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
|
||||
rhs(std::move(rhs)) {}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
|
||||
};
|
||||
|
||||
/// Expression class for function calls.
|
||||
class CallExprAST : public ExprAST {
|
||||
std::string Callee;
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
std::string callee;
|
||||
std::vector<std::unique_ptr<ExprAST>> args;
|
||||
|
||||
public:
|
||||
CallExprAST(Location loc, const std::string &Callee,
|
||||
std::vector<std::unique_ptr<ExprAST>> Args)
|
||||
: ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
|
||||
CallExprAST(Location loc, const std::string &callee,
|
||||
std::vector<std::unique_ptr<ExprAST>> args)
|
||||
: ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
|
||||
|
||||
llvm::StringRef getCallee() { return Callee; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
|
||||
llvm::StringRef getCallee() { return callee; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
|
||||
};
|
||||
|
||||
/// Expression class for builtin print calls.
|
||||
class PrintExprAST : public ExprAST {
|
||||
std::unique_ptr<ExprAST> Arg;
|
||||
std::unique_ptr<ExprAST> arg;
|
||||
|
||||
public:
|
||||
PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
|
||||
: ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
|
||||
PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
|
||||
: ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
|
||||
|
||||
ExprAST *getArg() { return Arg.get(); }
|
||||
ExprAST *getArg() { return arg.get(); }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
|
||||
};
|
||||
|
||||
/// This class represents the "prototype" for a function, which captures its
|
||||
|
@ -215,23 +215,21 @@ public:
|
|||
: location(location), name(name), args(std::move(args)) {}
|
||||
|
||||
const Location &loc() { return location; }
|
||||
const std::string &getName() const { return name; }
|
||||
const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
|
||||
return args;
|
||||
}
|
||||
llvm::StringRef getName() const { return name; }
|
||||
llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
|
||||
};
|
||||
|
||||
/// This class represents a function definition itself.
|
||||
class FunctionAST {
|
||||
std::unique_ptr<PrototypeAST> Proto;
|
||||
std::unique_ptr<ExprASTList> Body;
|
||||
std::unique_ptr<PrototypeAST> proto;
|
||||
std::unique_ptr<ExprASTList> body;
|
||||
|
||||
public:
|
||||
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
|
||||
std::unique_ptr<ExprASTList> Body)
|
||||
: Proto(std::move(Proto)), Body(std::move(Body)) {}
|
||||
PrototypeAST *getProto() { return Proto.get(); }
|
||||
ExprASTList *getBody() { return Body.get(); }
|
||||
FunctionAST(std::unique_ptr<PrototypeAST> proto,
|
||||
std::unique_ptr<ExprASTList> body)
|
||||
: proto(std::move(proto)), body(std::move(body)) {}
|
||||
PrototypeAST *getProto() { return proto.get(); }
|
||||
ExprASTList *getBody() { return body.get(); }
|
||||
};
|
||||
|
||||
/// This class represents a list of functions to be processed together
|
||||
|
|
|
@ -89,13 +89,13 @@ public:
|
|||
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
|
||||
llvm::StringRef getId() {
|
||||
assert(curTok == tok_identifier);
|
||||
return IdentifierStr;
|
||||
return identifierStr;
|
||||
}
|
||||
|
||||
/// Return the current number (prereq: getCurToken() == tok_number)
|
||||
double getValue() {
|
||||
assert(curTok == tok_number);
|
||||
return NumVal;
|
||||
return numVal;
|
||||
}
|
||||
|
||||
/// Return the location for the beginning of the current token.
|
||||
|
@ -135,56 +135,58 @@ private:
|
|||
/// Return the next token from standard input.
|
||||
Token getTok() {
|
||||
// Skip any whitespace.
|
||||
while (isspace(LastChar))
|
||||
LastChar = Token(getNextChar());
|
||||
while (isspace(lastChar))
|
||||
lastChar = Token(getNextChar());
|
||||
|
||||
// Save the current location before reading the token characters.
|
||||
lastLocation.line = curLineNum;
|
||||
lastLocation.col = curCol;
|
||||
|
||||
if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||
IdentifierStr = (char)LastChar;
|
||||
while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
|
||||
IdentifierStr += (char)LastChar;
|
||||
// Identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||
if (isalpha(lastChar)) {
|
||||
identifierStr = (char)lastChar;
|
||||
while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
|
||||
identifierStr += (char)lastChar;
|
||||
|
||||
if (IdentifierStr == "return")
|
||||
if (identifierStr == "return")
|
||||
return tok_return;
|
||||
if (IdentifierStr == "def")
|
||||
if (identifierStr == "def")
|
||||
return tok_def;
|
||||
if (IdentifierStr == "var")
|
||||
if (identifierStr == "var")
|
||||
return tok_var;
|
||||
return tok_identifier;
|
||||
}
|
||||
|
||||
if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
|
||||
std::string NumStr;
|
||||
// Number: [0-9.]+
|
||||
if (isdigit(lastChar) || lastChar == '.') {
|
||||
std::string numStr;
|
||||
do {
|
||||
NumStr += LastChar;
|
||||
LastChar = Token(getNextChar());
|
||||
} while (isdigit(LastChar) || LastChar == '.');
|
||||
numStr += lastChar;
|
||||
lastChar = Token(getNextChar());
|
||||
} while (isdigit(lastChar) || lastChar == '.');
|
||||
|
||||
NumVal = strtod(NumStr.c_str(), nullptr);
|
||||
numVal = strtod(numStr.c_str(), nullptr);
|
||||
return tok_number;
|
||||
}
|
||||
|
||||
if (LastChar == '#') {
|
||||
if (lastChar == '#') {
|
||||
// Comment until end of line.
|
||||
do
|
||||
LastChar = Token(getNextChar());
|
||||
while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
|
||||
do {
|
||||
lastChar = Token(getNextChar());
|
||||
} while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
|
||||
|
||||
if (LastChar != EOF)
|
||||
if (lastChar != EOF)
|
||||
return getTok();
|
||||
}
|
||||
|
||||
// Check for end of file. Don't eat the EOF.
|
||||
if (LastChar == EOF)
|
||||
if (lastChar == EOF)
|
||||
return tok_eof;
|
||||
|
||||
// Otherwise, just return the character as its ascii value.
|
||||
Token ThisChar = Token(LastChar);
|
||||
LastChar = Token(getNextChar());
|
||||
return ThisChar;
|
||||
Token thisChar = Token(lastChar);
|
||||
lastChar = Token(getNextChar());
|
||||
return thisChar;
|
||||
}
|
||||
|
||||
/// The last token read from the input.
|
||||
|
@ -194,15 +196,15 @@ private:
|
|||
Location lastLocation;
|
||||
|
||||
/// If the current Token is an identifier, this string contains the value.
|
||||
std::string IdentifierStr;
|
||||
std::string identifierStr;
|
||||
|
||||
/// If the current Token is a number, this contains the value.
|
||||
double NumVal = 0;
|
||||
double numVal = 0;
|
||||
|
||||
/// The last value returned by getNextChar(). We need to keep it around as we
|
||||
/// always need to read ahead one character to decide when to end a token and
|
||||
/// we can't put it back in the stream after reading from it.
|
||||
Token LastChar = Token(' ');
|
||||
Token lastChar = Token(' ');
|
||||
|
||||
/// Keep track of the current line number in the input stream
|
||||
int curLineNum = 0;
|
||||
|
|
|
@ -48,13 +48,13 @@ public:
|
|||
Parser(Lexer &lexer) : lexer(lexer) {}
|
||||
|
||||
/// Parse a full Module. A module is a list of function definitions.
|
||||
std::unique_ptr<ModuleAST> ParseModule() {
|
||||
std::unique_ptr<ModuleAST> parseModule() {
|
||||
lexer.getNextToken(); // prime the lexer
|
||||
|
||||
// Parse functions one at a time and accumulate in this vector.
|
||||
std::vector<FunctionAST> functions;
|
||||
while (auto F = ParseDefinition()) {
|
||||
functions.push_back(std::move(*F));
|
||||
while (auto f = parseDefinition()) {
|
||||
functions.push_back(std::move(*f));
|
||||
if (lexer.getCurToken() == tok_eof)
|
||||
break;
|
||||
}
|
||||
|
@ -70,14 +70,14 @@ private:
|
|||
|
||||
/// Parse a return statement.
|
||||
/// return :== return ; | return expr ;
|
||||
std::unique_ptr<ReturnExprAST> ParseReturn() {
|
||||
std::unique_ptr<ReturnExprAST> parseReturn() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_return);
|
||||
|
||||
// return takes an optional argument
|
||||
llvm::Optional<std::unique_ptr<ExprAST>> expr;
|
||||
if (lexer.getCurToken() != ';') {
|
||||
expr = ParseExpression();
|
||||
expr = parseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -86,18 +86,18 @@ private:
|
|||
|
||||
/// Parse a literal number.
|
||||
/// numberexpr ::= number
|
||||
std::unique_ptr<ExprAST> ParseNumberExpr() {
|
||||
std::unique_ptr<ExprAST> parseNumberExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
auto Result =
|
||||
auto result =
|
||||
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
|
||||
lexer.consume(tok_number);
|
||||
return std::move(Result);
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
/// Parse a literal array expression.
|
||||
/// tensorLiteral ::= [ literalList ] | number
|
||||
/// literalList ::= tensorLiteral | tensorLiteral, literalList
|
||||
std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
|
||||
std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(Token('['));
|
||||
|
||||
|
@ -108,13 +108,13 @@ private:
|
|||
do {
|
||||
// We can have either another nested array or a number literal.
|
||||
if (lexer.getCurToken() == '[') {
|
||||
values.push_back(ParseTensorLiteralExpr());
|
||||
values.push_back(parseTensorLiteralExpr());
|
||||
if (!values.back())
|
||||
return nullptr; // parse error in the nested array.
|
||||
} else {
|
||||
if (lexer.getCurToken() != tok_number)
|
||||
return parseError<ExprAST>("<num> or [", "in literal expression");
|
||||
values.push_back(ParseNumberExpr());
|
||||
values.push_back(parseNumberExpr());
|
||||
}
|
||||
|
||||
// End of this list on ']'
|
||||
|
@ -130,8 +130,10 @@ private:
|
|||
if (values.empty())
|
||||
return parseError<ExprAST>("<something>", "to fill literal expression");
|
||||
lexer.getNextToken(); // eat ]
|
||||
|
||||
/// Fill in the dimensions now. First the current nesting level:
|
||||
dims.push_back(values.size());
|
||||
|
||||
/// If there is any nested array, process all of them and ensure that
|
||||
/// dimensions are uniform.
|
||||
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
|
||||
|
@ -143,7 +145,7 @@ private:
|
|||
"inside literal expression");
|
||||
|
||||
// Append the nested dimensions to the current level
|
||||
auto &firstDims = firstLiteral->getDims();
|
||||
auto firstDims = firstLiteral->getDims();
|
||||
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
|
||||
|
||||
// Sanity check that shape is uniform across all elements of the list.
|
||||
|
@ -162,22 +164,22 @@ private:
|
|||
}
|
||||
|
||||
/// parenexpr ::= '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseParenExpr() {
|
||||
std::unique_ptr<ExprAST> parseParenExpr() {
|
||||
lexer.getNextToken(); // eat (.
|
||||
auto V = ParseExpression();
|
||||
if (!V)
|
||||
auto v = parseExpression();
|
||||
if (!v)
|
||||
return nullptr;
|
||||
|
||||
if (lexer.getCurToken() != ')')
|
||||
return parseError<ExprAST>(")", "to close expression with parentheses");
|
||||
lexer.consume(Token(')'));
|
||||
return V;
|
||||
return v;
|
||||
}
|
||||
|
||||
/// identifierexpr
|
||||
/// ::= identifier
|
||||
/// ::= identifier '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseIdentifierExpr() {
|
||||
std::unique_ptr<ExprAST> parseIdentifierExpr() {
|
||||
std::string name = lexer.getId();
|
||||
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
@ -188,11 +190,11 @@ private:
|
|||
|
||||
// This is a function call.
|
||||
lexer.consume(Token('('));
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
std::vector<std::unique_ptr<ExprAST>> args;
|
||||
if (lexer.getCurToken() != ')') {
|
||||
while (true) {
|
||||
if (auto Arg = ParseExpression())
|
||||
Args.push_back(std::move(Arg));
|
||||
if (auto arg = parseExpression())
|
||||
args.push_back(std::move(arg));
|
||||
else
|
||||
return nullptr;
|
||||
|
||||
|
@ -208,14 +210,14 @@ private:
|
|||
|
||||
// It can be a builtin call to print
|
||||
if (name == "print") {
|
||||
if (Args.size() != 1)
|
||||
if (args.size() != 1)
|
||||
return parseError<ExprAST>("<single arg>", "as argument to print()");
|
||||
|
||||
return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0]));
|
||||
return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
|
||||
}
|
||||
|
||||
// Call to a user-defined function
|
||||
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args));
|
||||
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
|
||||
}
|
||||
|
||||
/// primary
|
||||
|
@ -223,20 +225,20 @@ private:
|
|||
/// ::= numberexpr
|
||||
/// ::= parenexpr
|
||||
/// ::= tensorliteral
|
||||
std::unique_ptr<ExprAST> ParsePrimary() {
|
||||
std::unique_ptr<ExprAST> parsePrimary() {
|
||||
switch (lexer.getCurToken()) {
|
||||
default:
|
||||
llvm::errs() << "unknown token '" << lexer.getCurToken()
|
||||
<< "' when expecting an expression\n";
|
||||
return nullptr;
|
||||
case tok_identifier:
|
||||
return ParseIdentifierExpr();
|
||||
return parseIdentifierExpr();
|
||||
case tok_number:
|
||||
return ParseNumberExpr();
|
||||
return parseNumberExpr();
|
||||
case '(':
|
||||
return ParseParenExpr();
|
||||
return parseParenExpr();
|
||||
case '[':
|
||||
return ParseTensorLiteralExpr();
|
||||
return parseTensorLiteralExpr();
|
||||
case ';':
|
||||
return nullptr;
|
||||
case '}':
|
||||
|
@ -248,54 +250,54 @@ private:
|
|||
/// argument indicates the precedence of the current binary operator.
|
||||
///
|
||||
/// binoprhs ::= ('+' primary)*
|
||||
std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
|
||||
std::unique_ptr<ExprAST> LHS) {
|
||||
std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
|
||||
std::unique_ptr<ExprAST> lhs) {
|
||||
// If this is a binop, find its precedence.
|
||||
while (true) {
|
||||
int TokPrec = GetTokPrecedence();
|
||||
int tokPrec = getTokPrecedence();
|
||||
|
||||
// If this is a binop that binds at least as tightly as the current binop,
|
||||
// consume it, otherwise we are done.
|
||||
if (TokPrec < ExprPrec)
|
||||
return LHS;
|
||||
if (tokPrec < exprPrec)
|
||||
return lhs;
|
||||
|
||||
// Okay, we know this is a binop.
|
||||
int BinOp = lexer.getCurToken();
|
||||
lexer.consume(Token(BinOp));
|
||||
int binOp = lexer.getCurToken();
|
||||
lexer.consume(Token(binOp));
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
||||
// Parse the primary expression after the binary operator.
|
||||
auto RHS = ParsePrimary();
|
||||
if (!RHS)
|
||||
auto rhs = parsePrimary();
|
||||
if (!rhs)
|
||||
return parseError<ExprAST>("expression", "to complete binary operator");
|
||||
|
||||
// If BinOp binds less tightly with RHS than the operator after RHS, let
|
||||
// the pending operator take RHS as its LHS.
|
||||
int NextPrec = GetTokPrecedence();
|
||||
if (TokPrec < NextPrec) {
|
||||
RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
|
||||
if (!RHS)
|
||||
// If BinOp binds less tightly with rhs than the operator after rhs, let
|
||||
// the pending operator take rhs as its lhs.
|
||||
int nextPrec = getTokPrecedence();
|
||||
if (tokPrec < nextPrec) {
|
||||
rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
|
||||
if (!rhs)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Merge LHS/RHS.
|
||||
LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp,
|
||||
std::move(LHS), std::move(RHS));
|
||||
// Merge lhs/RHS.
|
||||
lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
|
||||
std::move(lhs), std::move(rhs));
|
||||
}
|
||||
}
|
||||
|
||||
/// expression::= primary binoprhs
|
||||
std::unique_ptr<ExprAST> ParseExpression() {
|
||||
auto LHS = ParsePrimary();
|
||||
if (!LHS)
|
||||
/// expression::= primary binop rhs
|
||||
std::unique_ptr<ExprAST> parseExpression() {
|
||||
auto lhs = parsePrimary();
|
||||
if (!lhs)
|
||||
return nullptr;
|
||||
|
||||
return ParseBinOpRHS(0, std::move(LHS));
|
||||
return parseBinOpRHS(0, std::move(lhs));
|
||||
}
|
||||
|
||||
/// type ::= < shape_list >
|
||||
/// shape_list ::= num | num , shape_list
|
||||
std::unique_ptr<VarType> ParseType() {
|
||||
std::unique_ptr<VarType> parseType() {
|
||||
if (lexer.getCurToken() != '<')
|
||||
return parseError<VarType>("<", "to begin type");
|
||||
lexer.getNextToken(); // eat <
|
||||
|
@ -319,7 +321,7 @@ private:
|
|||
/// and identifier and an optional type (shape specification) before the
|
||||
/// initializer.
|
||||
/// decl ::= var identifier [ type ] = expr
|
||||
std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
|
||||
std::unique_ptr<VarDeclExprAST> parseDeclaration() {
|
||||
if (lexer.getCurToken() != tok_var)
|
||||
return parseError<VarDeclExprAST>("var", "to begin declaration");
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
@ -333,7 +335,7 @@ private:
|
|||
|
||||
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
|
||||
if (lexer.getCurToken() == '<') {
|
||||
type = ParseType();
|
||||
type = parseType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -341,7 +343,7 @@ private:
|
|||
if (!type)
|
||||
type = std::make_unique<VarType>();
|
||||
lexer.consume(Token('='));
|
||||
auto expr = ParseExpression();
|
||||
auto expr = parseExpression();
|
||||
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
|
||||
std::move(*type), std::move(expr));
|
||||
}
|
||||
|
@ -352,7 +354,7 @@ private:
|
|||
/// block ::= { expression_list }
|
||||
/// expression_list ::= block_expr ; expression_list
|
||||
/// block_expr ::= decl | "return" | expr
|
||||
std::unique_ptr<ExprASTList> ParseBlock() {
|
||||
std::unique_ptr<ExprASTList> parseBlock() {
|
||||
if (lexer.getCurToken() != '{')
|
||||
return parseError<ExprASTList>("{", "to begin block");
|
||||
lexer.consume(Token('{'));
|
||||
|
@ -366,19 +368,19 @@ private:
|
|||
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
|
||||
if (lexer.getCurToken() == tok_var) {
|
||||
// Variable declaration
|
||||
auto varDecl = ParseDeclaration();
|
||||
auto varDecl = parseDeclaration();
|
||||
if (!varDecl)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(varDecl));
|
||||
} else if (lexer.getCurToken() == tok_return) {
|
||||
// Return statement
|
||||
auto ret = ParseReturn();
|
||||
auto ret = parseReturn();
|
||||
if (!ret)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(ret));
|
||||
} else {
|
||||
// General expression
|
||||
auto expr = ParseExpression();
|
||||
auto expr = parseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(expr));
|
||||
|
@ -401,13 +403,13 @@ private:
|
|||
|
||||
/// prototype ::= def id '(' decl_list ')'
|
||||
/// decl_list ::= identifier | identifier, decl_list
|
||||
std::unique_ptr<PrototypeAST> ParsePrototype() {
|
||||
std::unique_ptr<PrototypeAST> parsePrototype() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_def);
|
||||
if (lexer.getCurToken() != tok_identifier)
|
||||
return parseError<PrototypeAST>("function name", "in prototype");
|
||||
|
||||
std::string FnName = lexer.getId();
|
||||
std::string fnName = lexer.getId();
|
||||
lexer.consume(tok_identifier);
|
||||
|
||||
if (lexer.getCurToken() != '(')
|
||||
|
@ -435,7 +437,7 @@ private:
|
|||
|
||||
// success.
|
||||
lexer.consume(Token(')'));
|
||||
return std::make_unique<PrototypeAST>(std::move(loc), FnName,
|
||||
return std::make_unique<PrototypeAST>(std::move(loc), fnName,
|
||||
std::move(args));
|
||||
}
|
||||
|
||||
|
@ -443,18 +445,18 @@ private:
|
|||
/// `def` keyword, followed by a block containing a list of expressions.
|
||||
///
|
||||
/// definition ::= prototype block
|
||||
std::unique_ptr<FunctionAST> ParseDefinition() {
|
||||
auto Proto = ParsePrototype();
|
||||
if (!Proto)
|
||||
std::unique_ptr<FunctionAST> parseDefinition() {
|
||||
auto proto = parsePrototype();
|
||||
if (!proto)
|
||||
return nullptr;
|
||||
|
||||
if (auto block = ParseBlock())
|
||||
return std::make_unique<FunctionAST>(std::move(Proto), std::move(block));
|
||||
if (auto block = parseBlock())
|
||||
return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Get the precedence of the pending binary operator token.
|
||||
int GetTokPrecedence() {
|
||||
int getTokPrecedence() {
|
||||
if (!isascii(lexer.getCurToken()))
|
||||
return -1;
|
||||
|
||||
|
|
|
@ -143,7 +143,7 @@ private:
|
|||
// In MLIR the entry block of the function is special: it must have the same
|
||||
// argument list as the function itself.
|
||||
auto &entryBlock = *function.addEntryBlock();
|
||||
auto &protoArgs = funcAST.getProto()->getArgs();
|
||||
auto protoArgs = funcAST.getProto()->getArgs();
|
||||
|
||||
// Declare all the function arguments in the symbol table.
|
||||
for (const auto &name_value :
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "toy/AST.h"
|
||||
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
|
@ -40,22 +41,22 @@ struct Indent {
|
|||
/// the way. The only data member is the current indentation level.
|
||||
class ASTDumper {
|
||||
public:
|
||||
void dump(ModuleAST *Node);
|
||||
void dump(ModuleAST *node);
|
||||
|
||||
private:
|
||||
void dump(VarType &type);
|
||||
void dump(const VarType &type);
|
||||
void dump(VarDeclExprAST *varDecl);
|
||||
void dump(ExprAST *expr);
|
||||
void dump(ExprASTList *exprList);
|
||||
void dump(NumberExprAST *num);
|
||||
void dump(LiteralExprAST *Node);
|
||||
void dump(VariableExprAST *Node);
|
||||
void dump(ReturnExprAST *Node);
|
||||
void dump(BinaryExprAST *Node);
|
||||
void dump(CallExprAST *Node);
|
||||
void dump(PrintExprAST *Node);
|
||||
void dump(PrototypeAST *Node);
|
||||
void dump(FunctionAST *Node);
|
||||
void dump(LiteralExprAST *node);
|
||||
void dump(VariableExprAST *node);
|
||||
void dump(ReturnExprAST *node);
|
||||
void dump(BinaryExprAST *node);
|
||||
void dump(CallExprAST *node);
|
||||
void dump(PrintExprAST *node);
|
||||
void dump(PrototypeAST *node);
|
||||
void dump(FunctionAST *node);
|
||||
|
||||
// Actually print spaces matching the current indentation level
|
||||
void indent() {
|
||||
|
@ -68,8 +69,8 @@ private:
|
|||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *Node) {
|
||||
const auto &loc = Node->loc();
|
||||
template <typename T> static std::string loc(T *node) {
|
||||
const auto &loc = node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
.str();
|
||||
|
@ -125,60 +126,50 @@ void ASTDumper::dump(NumberExprAST *num) {
|
|||
llvm::errs() << num->getValue() << " " << loc(num) << "\n";
|
||||
}
|
||||
|
||||
/// Helper to print recurisvely a literal. This handles nested array like:
|
||||
/// Helper to print recursively a literal. This handles nested array like:
|
||||
/// [ [ 1, 2 ], [ 3, 4 ] ]
|
||||
/// We print out such array with the dimensions spelled out at every level:
|
||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||
void printLitHelper(ExprAST *lit_or_num) {
|
||||
void printLitHelper(ExprAST *litOrNum) {
|
||||
// Inside a literal expression we can have either a number or another literal
|
||||
if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
|
||||
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
||||
llvm::errs() << num->getValue();
|
||||
return;
|
||||
}
|
||||
auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
|
||||
auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
|
||||
|
||||
// Print the dimension for this literal first
|
||||
llvm::errs() << "<";
|
||||
{
|
||||
const char *sep = "";
|
||||
for (auto dim : literal->getDims()) {
|
||||
llvm::errs() << sep << dim;
|
||||
sep = ", ";
|
||||
}
|
||||
}
|
||||
mlir::interleaveComma(literal->getDims(), llvm::errs());
|
||||
llvm::errs() << ">";
|
||||
|
||||
// Now print the content, recursing on every element of the list
|
||||
llvm::errs() << "[ ";
|
||||
const char *sep = "";
|
||||
for (auto &elt : literal->getValues()) {
|
||||
llvm::errs() << sep;
|
||||
printLitHelper(elt.get());
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(literal->getValues(), llvm::errs(),
|
||||
[&](auto &elt) { printLitHelper(elt.get()); });
|
||||
llvm::errs() << "]";
|
||||
}
|
||||
|
||||
/// Print a literal, see the recursive helper above for the implementation.
|
||||
void ASTDumper::dump(LiteralExprAST *Node) {
|
||||
void ASTDumper::dump(LiteralExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Literal: ";
|
||||
printLitHelper(Node);
|
||||
llvm::errs() << " " << loc(Node) << "\n";
|
||||
printLitHelper(node);
|
||||
llvm::errs() << " " << loc(node) << "\n";
|
||||
}
|
||||
|
||||
/// Print a variable reference (just a name).
|
||||
void ASTDumper::dump(VariableExprAST *Node) {
|
||||
void ASTDumper::dump(VariableExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
|
||||
llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
|
||||
}
|
||||
|
||||
/// Return statement print the return and its (optional) argument.
|
||||
void ASTDumper::dump(ReturnExprAST *Node) {
|
||||
void ASTDumper::dump(ReturnExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Return\n";
|
||||
if (Node->getExpr().hasValue())
|
||||
return dump(*Node->getExpr());
|
||||
if (node->getExpr().hasValue())
|
||||
return dump(*node->getExpr());
|
||||
{
|
||||
INDENT();
|
||||
llvm::errs() << "(void)\n";
|
||||
|
@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) {
|
|||
}
|
||||
|
||||
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
|
||||
void ASTDumper::dump(BinaryExprAST *Node) {
|
||||
void ASTDumper::dump(BinaryExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
|
||||
dump(Node->getLHS());
|
||||
dump(Node->getRHS());
|
||||
llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
|
||||
dump(node->getLHS());
|
||||
dump(node->getRHS());
|
||||
}
|
||||
|
||||
/// Print a call expression, first the callee name and the list of args by
|
||||
/// recursing into each individual argument.
|
||||
void ASTDumper::dump(CallExprAST *Node) {
|
||||
void ASTDumper::dump(CallExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
|
||||
for (auto &arg : Node->getArgs())
|
||||
llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
|
||||
for (auto &arg : node->getArgs())
|
||||
dump(arg.get());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a builtin print call, first the builtin name and then the argument.
|
||||
void ASTDumper::dump(PrintExprAST *Node) {
|
||||
void ASTDumper::dump(PrintExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Print [ " << loc(Node) << "\n";
|
||||
dump(Node->getArg());
|
||||
llvm::errs() << "Print [ " << loc(node) << "\n";
|
||||
dump(node->getArg());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print type: only the shape is printed in between '<' and '>'
|
||||
void ASTDumper::dump(VarType &type) {
|
||||
void ASTDumper::dump(const VarType &type) {
|
||||
llvm::errs() << "<";
|
||||
const char *sep = "";
|
||||
for (auto shape : type.shape) {
|
||||
llvm::errs() << sep << shape;
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(type.shape, llvm::errs());
|
||||
llvm::errs() << ">";
|
||||
}
|
||||
|
||||
/// Print a function prototype, first the function name, and then the list of
|
||||
/// parameters names.
|
||||
void ASTDumper::dump(PrototypeAST *Node) {
|
||||
void ASTDumper::dump(PrototypeAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
|
||||
llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
|
||||
indent();
|
||||
llvm::errs() << "Params: [";
|
||||
const char *sep = "";
|
||||
for (auto &arg : Node->getArgs()) {
|
||||
llvm::errs() << sep << arg->getName();
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(node->getArgs(), llvm::errs(),
|
||||
[](auto &arg) { llvm::errs() << arg->getName(); });
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a function, first the prototype and then the body.
|
||||
void ASTDumper::dump(FunctionAST *Node) {
|
||||
void ASTDumper::dump(FunctionAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Function \n";
|
||||
dump(Node->getProto());
|
||||
dump(Node->getBody());
|
||||
dump(node->getProto());
|
||||
dump(node->getBody());
|
||||
}
|
||||
|
||||
/// Print a module, actually loop over the functions and print them in sequence.
|
||||
void ASTDumper::dump(ModuleAST *Node) {
|
||||
void ASTDumper::dump(ModuleAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Module:\n";
|
||||
for (auto &F : *Node)
|
||||
dump(&F);
|
||||
for (auto &f : *node)
|
||||
dump(&f);
|
||||
}
|
||||
|
||||
namespace toy {
|
||||
|
|
|
@ -63,20 +63,20 @@ static cl::opt<enum Action> emitAction(
|
|||
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
||||
cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")));
|
||||
|
||||
static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
|
||||
static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
|
||||
|
||||
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
|
||||
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
llvm::MemoryBuffer::getFileOrSTDIN(filename);
|
||||
if (std::error_code EC = FileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
||||
if (std::error_code ec = fileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
auto buffer = FileOrErr.get()->getBuffer();
|
||||
auto buffer = fileOrErr.get()->getBuffer();
|
||||
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
|
||||
Parser parser(lexer);
|
||||
return parser.ParseModule();
|
||||
return parser.parseModule();
|
||||
}
|
||||
|
||||
int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
|
||||
|
@ -118,7 +118,7 @@ int dumpMLIR() {
|
|||
if (int error = loadMLIR(sourceMgr, context, module))
|
||||
return error;
|
||||
|
||||
if (EnableOpt) {
|
||||
if (enableOpt) {
|
||||
mlir::PassManager pm(&context);
|
||||
// Apply any generic pass manager command line options and run the pipeline.
|
||||
applyPassManagerCLOptions(pm);
|
||||
|
|
|
@ -54,7 +54,6 @@ public:
|
|||
|
||||
ExprAST(ExprASTKind kind, Location location)
|
||||
: kind(kind), location(location) {}
|
||||
|
||||
virtual ~ExprAST() = default;
|
||||
|
||||
ExprASTKind getKind() const { return kind; }
|
||||
|
@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST {
|
|||
double Val;
|
||||
|
||||
public:
|
||||
NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
|
||||
NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
|
||||
|
||||
double getValue() { return Val; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
|
||||
};
|
||||
|
||||
/// Expression class for a literal value.
|
||||
|
@ -93,10 +92,11 @@ public:
|
|||
: ExprAST(Expr_Literal, loc), values(std::move(values)),
|
||||
dims(std::move(dims)) {}
|
||||
|
||||
std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
|
||||
std::vector<int64_t> &getDims() { return dims; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
|
||||
llvm::ArrayRef<int64_t> getDims() { return dims; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
|
||||
};
|
||||
|
||||
/// Expression class for referencing a variable, like "a".
|
||||
|
@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST {
|
|||
std::string name;
|
||||
|
||||
public:
|
||||
VariableExprAST(Location loc, const std::string &name)
|
||||
VariableExprAST(Location loc, llvm::StringRef name)
|
||||
: ExprAST(Expr_Var, loc), name(name) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
|
||||
};
|
||||
|
||||
/// Expression class for defining a variable.
|
||||
|
@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST {
|
|||
std::unique_ptr<ExprAST> initVal;
|
||||
|
||||
public:
|
||||
VarDeclExprAST(Location loc, const std::string &name, VarType type,
|
||||
VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
|
||||
std::unique_ptr<ExprAST> initVal)
|
||||
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
|
||||
initVal(std::move(initVal)) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
ExprAST *getInitVal() { return initVal.get(); }
|
||||
VarType &getType() { return type; }
|
||||
const VarType &getType() { return type; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
|
||||
};
|
||||
|
||||
/// Expression class for a return operator.
|
||||
|
@ -144,61 +144,61 @@ public:
|
|||
llvm::Optional<ExprAST *> getExpr() {
|
||||
if (expr.hasValue())
|
||||
return expr->get();
|
||||
return llvm::NoneType();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
|
||||
};
|
||||
|
||||
/// Expression class for a binary operator.
|
||||
class BinaryExprAST : public ExprAST {
|
||||
char Op;
|
||||
std::unique_ptr<ExprAST> LHS, RHS;
|
||||
char op;
|
||||
std::unique_ptr<ExprAST> lhs, rhs;
|
||||
|
||||
public:
|
||||
char getOp() { return Op; }
|
||||
ExprAST *getLHS() { return LHS.get(); }
|
||||
ExprAST *getRHS() { return RHS.get(); }
|
||||
char getOp() { return op; }
|
||||
ExprAST *getLHS() { return lhs.get(); }
|
||||
ExprAST *getRHS() { return rhs.get(); }
|
||||
|
||||
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
|
||||
std::unique_ptr<ExprAST> RHS)
|
||||
: ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
|
||||
RHS(std::move(RHS)) {}
|
||||
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
|
||||
std::unique_ptr<ExprAST> rhs)
|
||||
: ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
|
||||
rhs(std::move(rhs)) {}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
|
||||
};
|
||||
|
||||
/// Expression class for function calls.
|
||||
class CallExprAST : public ExprAST {
|
||||
std::string Callee;
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
std::string callee;
|
||||
std::vector<std::unique_ptr<ExprAST>> args;
|
||||
|
||||
public:
|
||||
CallExprAST(Location loc, const std::string &Callee,
|
||||
std::vector<std::unique_ptr<ExprAST>> Args)
|
||||
: ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
|
||||
CallExprAST(Location loc, const std::string &callee,
|
||||
std::vector<std::unique_ptr<ExprAST>> args)
|
||||
: ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
|
||||
|
||||
llvm::StringRef getCallee() { return Callee; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
|
||||
llvm::StringRef getCallee() { return callee; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
|
||||
};
|
||||
|
||||
/// Expression class for builtin print calls.
|
||||
class PrintExprAST : public ExprAST {
|
||||
std::unique_ptr<ExprAST> Arg;
|
||||
std::unique_ptr<ExprAST> arg;
|
||||
|
||||
public:
|
||||
PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
|
||||
: ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
|
||||
PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
|
||||
: ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
|
||||
|
||||
ExprAST *getArg() { return Arg.get(); }
|
||||
ExprAST *getArg() { return arg.get(); }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
|
||||
};
|
||||
|
||||
/// This class represents the "prototype" for a function, which captures its
|
||||
|
@ -215,23 +215,21 @@ public:
|
|||
: location(location), name(name), args(std::move(args)) {}
|
||||
|
||||
const Location &loc() { return location; }
|
||||
const std::string &getName() const { return name; }
|
||||
const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
|
||||
return args;
|
||||
}
|
||||
llvm::StringRef getName() const { return name; }
|
||||
llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
|
||||
};
|
||||
|
||||
/// This class represents a function definition itself.
|
||||
class FunctionAST {
|
||||
std::unique_ptr<PrototypeAST> Proto;
|
||||
std::unique_ptr<ExprASTList> Body;
|
||||
std::unique_ptr<PrototypeAST> proto;
|
||||
std::unique_ptr<ExprASTList> body;
|
||||
|
||||
public:
|
||||
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
|
||||
std::unique_ptr<ExprASTList> Body)
|
||||
: Proto(std::move(Proto)), Body(std::move(Body)) {}
|
||||
PrototypeAST *getProto() { return Proto.get(); }
|
||||
ExprASTList *getBody() { return Body.get(); }
|
||||
FunctionAST(std::unique_ptr<PrototypeAST> proto,
|
||||
std::unique_ptr<ExprASTList> body)
|
||||
: proto(std::move(proto)), body(std::move(body)) {}
|
||||
PrototypeAST *getProto() { return proto.get(); }
|
||||
ExprASTList *getBody() { return body.get(); }
|
||||
};
|
||||
|
||||
/// This class represents a list of functions to be processed together
|
||||
|
|
|
@ -89,13 +89,13 @@ public:
|
|||
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
|
||||
llvm::StringRef getId() {
|
||||
assert(curTok == tok_identifier);
|
||||
return IdentifierStr;
|
||||
return identifierStr;
|
||||
}
|
||||
|
||||
/// Return the current number (prereq: getCurToken() == tok_number)
|
||||
double getValue() {
|
||||
assert(curTok == tok_number);
|
||||
return NumVal;
|
||||
return numVal;
|
||||
}
|
||||
|
||||
/// Return the location for the beginning of the current token.
|
||||
|
@ -135,56 +135,58 @@ private:
|
|||
/// Return the next token from standard input.
|
||||
Token getTok() {
|
||||
// Skip any whitespace.
|
||||
while (isspace(LastChar))
|
||||
LastChar = Token(getNextChar());
|
||||
while (isspace(lastChar))
|
||||
lastChar = Token(getNextChar());
|
||||
|
||||
// Save the current location before reading the token characters.
|
||||
lastLocation.line = curLineNum;
|
||||
lastLocation.col = curCol;
|
||||
|
||||
if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||
IdentifierStr = (char)LastChar;
|
||||
while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
|
||||
IdentifierStr += (char)LastChar;
|
||||
// Identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||
if (isalpha(lastChar)) {
|
||||
identifierStr = (char)lastChar;
|
||||
while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
|
||||
identifierStr += (char)lastChar;
|
||||
|
||||
if (IdentifierStr == "return")
|
||||
if (identifierStr == "return")
|
||||
return tok_return;
|
||||
if (IdentifierStr == "def")
|
||||
if (identifierStr == "def")
|
||||
return tok_def;
|
||||
if (IdentifierStr == "var")
|
||||
if (identifierStr == "var")
|
||||
return tok_var;
|
||||
return tok_identifier;
|
||||
}
|
||||
|
||||
if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
|
||||
std::string NumStr;
|
||||
// Number: [0-9.]+
|
||||
if (isdigit(lastChar) || lastChar == '.') {
|
||||
std::string numStr;
|
||||
do {
|
||||
NumStr += LastChar;
|
||||
LastChar = Token(getNextChar());
|
||||
} while (isdigit(LastChar) || LastChar == '.');
|
||||
numStr += lastChar;
|
||||
lastChar = Token(getNextChar());
|
||||
} while (isdigit(lastChar) || lastChar == '.');
|
||||
|
||||
NumVal = strtod(NumStr.c_str(), nullptr);
|
||||
numVal = strtod(numStr.c_str(), nullptr);
|
||||
return tok_number;
|
||||
}
|
||||
|
||||
if (LastChar == '#') {
|
||||
if (lastChar == '#') {
|
||||
// Comment until end of line.
|
||||
do
|
||||
LastChar = Token(getNextChar());
|
||||
while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
|
||||
do {
|
||||
lastChar = Token(getNextChar());
|
||||
} while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
|
||||
|
||||
if (LastChar != EOF)
|
||||
if (lastChar != EOF)
|
||||
return getTok();
|
||||
}
|
||||
|
||||
// Check for end of file. Don't eat the EOF.
|
||||
if (LastChar == EOF)
|
||||
if (lastChar == EOF)
|
||||
return tok_eof;
|
||||
|
||||
// Otherwise, just return the character as its ascii value.
|
||||
Token ThisChar = Token(LastChar);
|
||||
LastChar = Token(getNextChar());
|
||||
return ThisChar;
|
||||
Token thisChar = Token(lastChar);
|
||||
lastChar = Token(getNextChar());
|
||||
return thisChar;
|
||||
}
|
||||
|
||||
/// The last token read from the input.
|
||||
|
@ -194,15 +196,15 @@ private:
|
|||
Location lastLocation;
|
||||
|
||||
/// If the current Token is an identifier, this string contains the value.
|
||||
std::string IdentifierStr;
|
||||
std::string identifierStr;
|
||||
|
||||
/// If the current Token is a number, this contains the value.
|
||||
double NumVal = 0;
|
||||
double numVal = 0;
|
||||
|
||||
/// The last value returned by getNextChar(). We need to keep it around as we
|
||||
/// always need to read ahead one character to decide when to end a token and
|
||||
/// we can't put it back in the stream after reading from it.
|
||||
Token LastChar = Token(' ');
|
||||
Token lastChar = Token(' ');
|
||||
|
||||
/// Keep track of the current line number in the input stream
|
||||
int curLineNum = 0;
|
||||
|
|
|
@ -48,13 +48,13 @@ public:
|
|||
Parser(Lexer &lexer) : lexer(lexer) {}
|
||||
|
||||
/// Parse a full Module. A module is a list of function definitions.
|
||||
std::unique_ptr<ModuleAST> ParseModule() {
|
||||
std::unique_ptr<ModuleAST> parseModule() {
|
||||
lexer.getNextToken(); // prime the lexer
|
||||
|
||||
// Parse functions one at a time and accumulate in this vector.
|
||||
std::vector<FunctionAST> functions;
|
||||
while (auto F = ParseDefinition()) {
|
||||
functions.push_back(std::move(*F));
|
||||
while (auto f = parseDefinition()) {
|
||||
functions.push_back(std::move(*f));
|
||||
if (lexer.getCurToken() == tok_eof)
|
||||
break;
|
||||
}
|
||||
|
@ -70,14 +70,14 @@ private:
|
|||
|
||||
/// Parse a return statement.
|
||||
/// return :== return ; | return expr ;
|
||||
std::unique_ptr<ReturnExprAST> ParseReturn() {
|
||||
std::unique_ptr<ReturnExprAST> parseReturn() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_return);
|
||||
|
||||
// return takes an optional argument
|
||||
llvm::Optional<std::unique_ptr<ExprAST>> expr;
|
||||
if (lexer.getCurToken() != ';') {
|
||||
expr = ParseExpression();
|
||||
expr = parseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -86,18 +86,18 @@ private:
|
|||
|
||||
/// Parse a literal number.
|
||||
/// numberexpr ::= number
|
||||
std::unique_ptr<ExprAST> ParseNumberExpr() {
|
||||
std::unique_ptr<ExprAST> parseNumberExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
auto Result =
|
||||
auto result =
|
||||
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
|
||||
lexer.consume(tok_number);
|
||||
return std::move(Result);
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
/// Parse a literal array expression.
|
||||
/// tensorLiteral ::= [ literalList ] | number
|
||||
/// literalList ::= tensorLiteral | tensorLiteral, literalList
|
||||
std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
|
||||
std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(Token('['));
|
||||
|
||||
|
@ -108,13 +108,13 @@ private:
|
|||
do {
|
||||
// We can have either another nested array or a number literal.
|
||||
if (lexer.getCurToken() == '[') {
|
||||
values.push_back(ParseTensorLiteralExpr());
|
||||
values.push_back(parseTensorLiteralExpr());
|
||||
if (!values.back())
|
||||
return nullptr; // parse error in the nested array.
|
||||
} else {
|
||||
if (lexer.getCurToken() != tok_number)
|
||||
return parseError<ExprAST>("<num> or [", "in literal expression");
|
||||
values.push_back(ParseNumberExpr());
|
||||
values.push_back(parseNumberExpr());
|
||||
}
|
||||
|
||||
// End of this list on ']'
|
||||
|
@ -130,8 +130,10 @@ private:
|
|||
if (values.empty())
|
||||
return parseError<ExprAST>("<something>", "to fill literal expression");
|
||||
lexer.getNextToken(); // eat ]
|
||||
|
||||
/// Fill in the dimensions now. First the current nesting level:
|
||||
dims.push_back(values.size());
|
||||
|
||||
/// If there is any nested array, process all of them and ensure that
|
||||
/// dimensions are uniform.
|
||||
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
|
||||
|
@ -143,7 +145,7 @@ private:
|
|||
"inside literal expression");
|
||||
|
||||
// Append the nested dimensions to the current level
|
||||
auto &firstDims = firstLiteral->getDims();
|
||||
auto firstDims = firstLiteral->getDims();
|
||||
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
|
||||
|
||||
// Sanity check that shape is uniform across all elements of the list.
|
||||
|
@ -162,22 +164,22 @@ private:
|
|||
}
|
||||
|
||||
/// parenexpr ::= '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseParenExpr() {
|
||||
std::unique_ptr<ExprAST> parseParenExpr() {
|
||||
lexer.getNextToken(); // eat (.
|
||||
auto V = ParseExpression();
|
||||
if (!V)
|
||||
auto v = parseExpression();
|
||||
if (!v)
|
||||
return nullptr;
|
||||
|
||||
if (lexer.getCurToken() != ')')
|
||||
return parseError<ExprAST>(")", "to close expression with parentheses");
|
||||
lexer.consume(Token(')'));
|
||||
return V;
|
||||
return v;
|
||||
}
|
||||
|
||||
/// identifierexpr
|
||||
/// ::= identifier
|
||||
/// ::= identifier '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseIdentifierExpr() {
|
||||
std::unique_ptr<ExprAST> parseIdentifierExpr() {
|
||||
std::string name = lexer.getId();
|
||||
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
@ -188,11 +190,11 @@ private:
|
|||
|
||||
// This is a function call.
|
||||
lexer.consume(Token('('));
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
std::vector<std::unique_ptr<ExprAST>> args;
|
||||
if (lexer.getCurToken() != ')') {
|
||||
while (true) {
|
||||
if (auto Arg = ParseExpression())
|
||||
Args.push_back(std::move(Arg));
|
||||
if (auto arg = parseExpression())
|
||||
args.push_back(std::move(arg));
|
||||
else
|
||||
return nullptr;
|
||||
|
||||
|
@ -208,14 +210,14 @@ private:
|
|||
|
||||
// It can be a builtin call to print
|
||||
if (name == "print") {
|
||||
if (Args.size() != 1)
|
||||
if (args.size() != 1)
|
||||
return parseError<ExprAST>("<single arg>", "as argument to print()");
|
||||
|
||||
return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0]));
|
||||
return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
|
||||
}
|
||||
|
||||
// Call to a user-defined function
|
||||
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args));
|
||||
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
|
||||
}
|
||||
|
||||
/// primary
|
||||
|
@ -223,20 +225,20 @@ private:
|
|||
/// ::= numberexpr
|
||||
/// ::= parenexpr
|
||||
/// ::= tensorliteral
|
||||
std::unique_ptr<ExprAST> ParsePrimary() {
|
||||
std::unique_ptr<ExprAST> parsePrimary() {
|
||||
switch (lexer.getCurToken()) {
|
||||
default:
|
||||
llvm::errs() << "unknown token '" << lexer.getCurToken()
|
||||
<< "' when expecting an expression\n";
|
||||
return nullptr;
|
||||
case tok_identifier:
|
||||
return ParseIdentifierExpr();
|
||||
return parseIdentifierExpr();
|
||||
case tok_number:
|
||||
return ParseNumberExpr();
|
||||
return parseNumberExpr();
|
||||
case '(':
|
||||
return ParseParenExpr();
|
||||
return parseParenExpr();
|
||||
case '[':
|
||||
return ParseTensorLiteralExpr();
|
||||
return parseTensorLiteralExpr();
|
||||
case ';':
|
||||
return nullptr;
|
||||
case '}':
|
||||
|
@ -248,54 +250,54 @@ private:
|
|||
/// argument indicates the precedence of the current binary operator.
|
||||
///
|
||||
/// binoprhs ::= ('+' primary)*
|
||||
std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
|
||||
std::unique_ptr<ExprAST> LHS) {
|
||||
std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
|
||||
std::unique_ptr<ExprAST> lhs) {
|
||||
// If this is a binop, find its precedence.
|
||||
while (true) {
|
||||
int TokPrec = GetTokPrecedence();
|
||||
int tokPrec = getTokPrecedence();
|
||||
|
||||
// If this is a binop that binds at least as tightly as the current binop,
|
||||
// consume it, otherwise we are done.
|
||||
if (TokPrec < ExprPrec)
|
||||
return LHS;
|
||||
if (tokPrec < exprPrec)
|
||||
return lhs;
|
||||
|
||||
// Okay, we know this is a binop.
|
||||
int BinOp = lexer.getCurToken();
|
||||
lexer.consume(Token(BinOp));
|
||||
int binOp = lexer.getCurToken();
|
||||
lexer.consume(Token(binOp));
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
||||
// Parse the primary expression after the binary operator.
|
||||
auto RHS = ParsePrimary();
|
||||
if (!RHS)
|
||||
auto rhs = parsePrimary();
|
||||
if (!rhs)
|
||||
return parseError<ExprAST>("expression", "to complete binary operator");
|
||||
|
||||
// If BinOp binds less tightly with RHS than the operator after RHS, let
|
||||
// the pending operator take RHS as its LHS.
|
||||
int NextPrec = GetTokPrecedence();
|
||||
if (TokPrec < NextPrec) {
|
||||
RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
|
||||
if (!RHS)
|
||||
// If BinOp binds less tightly with rhs than the operator after rhs, let
|
||||
// the pending operator take rhs as its lhs.
|
||||
int nextPrec = getTokPrecedence();
|
||||
if (tokPrec < nextPrec) {
|
||||
rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
|
||||
if (!rhs)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Merge LHS/RHS.
|
||||
LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp,
|
||||
std::move(LHS), std::move(RHS));
|
||||
// Merge lhs/RHS.
|
||||
lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
|
||||
std::move(lhs), std::move(rhs));
|
||||
}
|
||||
}
|
||||
|
||||
/// expression::= primary binoprhs
|
||||
std::unique_ptr<ExprAST> ParseExpression() {
|
||||
auto LHS = ParsePrimary();
|
||||
if (!LHS)
|
||||
/// expression::= primary binop rhs
|
||||
std::unique_ptr<ExprAST> parseExpression() {
|
||||
auto lhs = parsePrimary();
|
||||
if (!lhs)
|
||||
return nullptr;
|
||||
|
||||
return ParseBinOpRHS(0, std::move(LHS));
|
||||
return parseBinOpRHS(0, std::move(lhs));
|
||||
}
|
||||
|
||||
/// type ::= < shape_list >
|
||||
/// shape_list ::= num | num , shape_list
|
||||
std::unique_ptr<VarType> ParseType() {
|
||||
std::unique_ptr<VarType> parseType() {
|
||||
if (lexer.getCurToken() != '<')
|
||||
return parseError<VarType>("<", "to begin type");
|
||||
lexer.getNextToken(); // eat <
|
||||
|
@ -319,7 +321,7 @@ private:
|
|||
/// and identifier and an optional type (shape specification) before the
|
||||
/// initializer.
|
||||
/// decl ::= var identifier [ type ] = expr
|
||||
std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
|
||||
std::unique_ptr<VarDeclExprAST> parseDeclaration() {
|
||||
if (lexer.getCurToken() != tok_var)
|
||||
return parseError<VarDeclExprAST>("var", "to begin declaration");
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
@ -333,7 +335,7 @@ private:
|
|||
|
||||
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
|
||||
if (lexer.getCurToken() == '<') {
|
||||
type = ParseType();
|
||||
type = parseType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -341,7 +343,7 @@ private:
|
|||
if (!type)
|
||||
type = std::make_unique<VarType>();
|
||||
lexer.consume(Token('='));
|
||||
auto expr = ParseExpression();
|
||||
auto expr = parseExpression();
|
||||
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
|
||||
std::move(*type), std::move(expr));
|
||||
}
|
||||
|
@ -352,7 +354,7 @@ private:
|
|||
/// block ::= { expression_list }
|
||||
/// expression_list ::= block_expr ; expression_list
|
||||
/// block_expr ::= decl | "return" | expr
|
||||
std::unique_ptr<ExprASTList> ParseBlock() {
|
||||
std::unique_ptr<ExprASTList> parseBlock() {
|
||||
if (lexer.getCurToken() != '{')
|
||||
return parseError<ExprASTList>("{", "to begin block");
|
||||
lexer.consume(Token('{'));
|
||||
|
@ -366,19 +368,19 @@ private:
|
|||
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
|
||||
if (lexer.getCurToken() == tok_var) {
|
||||
// Variable declaration
|
||||
auto varDecl = ParseDeclaration();
|
||||
auto varDecl = parseDeclaration();
|
||||
if (!varDecl)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(varDecl));
|
||||
} else if (lexer.getCurToken() == tok_return) {
|
||||
// Return statement
|
||||
auto ret = ParseReturn();
|
||||
auto ret = parseReturn();
|
||||
if (!ret)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(ret));
|
||||
} else {
|
||||
// General expression
|
||||
auto expr = ParseExpression();
|
||||
auto expr = parseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(expr));
|
||||
|
@ -401,13 +403,13 @@ private:
|
|||
|
||||
/// prototype ::= def id '(' decl_list ')'
|
||||
/// decl_list ::= identifier | identifier, decl_list
|
||||
std::unique_ptr<PrototypeAST> ParsePrototype() {
|
||||
std::unique_ptr<PrototypeAST> parsePrototype() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_def);
|
||||
if (lexer.getCurToken() != tok_identifier)
|
||||
return parseError<PrototypeAST>("function name", "in prototype");
|
||||
|
||||
std::string FnName = lexer.getId();
|
||||
std::string fnName = lexer.getId();
|
||||
lexer.consume(tok_identifier);
|
||||
|
||||
if (lexer.getCurToken() != '(')
|
||||
|
@ -435,7 +437,7 @@ private:
|
|||
|
||||
// success.
|
||||
lexer.consume(Token(')'));
|
||||
return std::make_unique<PrototypeAST>(std::move(loc), FnName,
|
||||
return std::make_unique<PrototypeAST>(std::move(loc), fnName,
|
||||
std::move(args));
|
||||
}
|
||||
|
||||
|
@ -443,18 +445,18 @@ private:
|
|||
/// `def` keyword, followed by a block containing a list of expressions.
|
||||
///
|
||||
/// definition ::= prototype block
|
||||
std::unique_ptr<FunctionAST> ParseDefinition() {
|
||||
auto Proto = ParsePrototype();
|
||||
if (!Proto)
|
||||
std::unique_ptr<FunctionAST> parseDefinition() {
|
||||
auto proto = parsePrototype();
|
||||
if (!proto)
|
||||
return nullptr;
|
||||
|
||||
if (auto block = ParseBlock())
|
||||
return std::make_unique<FunctionAST>(std::move(Proto), std::move(block));
|
||||
if (auto block = parseBlock())
|
||||
return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Get the precedence of the pending binary operator token.
|
||||
int GetTokPrecedence() {
|
||||
int getTokPrecedence() {
|
||||
if (!isascii(lexer.getCurToken()))
|
||||
return -1;
|
||||
|
||||
|
|
|
@ -91,7 +91,7 @@ private:
|
|||
mlir::ModuleOp theModule;
|
||||
|
||||
/// The builder is a helper class to create IR inside a function. The builder
|
||||
/// is stateful, in particular it keeeps an "insertion point": this is where
|
||||
/// is stateful, in particular it keeps an "insertion point": this is where
|
||||
/// the next operations will be introduced.
|
||||
mlir::OpBuilder builder;
|
||||
|
||||
|
@ -143,7 +143,7 @@ private:
|
|||
// In MLIR the entry block of the function is special: it must have the same
|
||||
// argument list as the function itself.
|
||||
auto &entryBlock = *function.addEntryBlock();
|
||||
auto &protoArgs = funcAST.getProto()->getArgs();
|
||||
auto protoArgs = funcAST.getProto()->getArgs();
|
||||
|
||||
// Declare all the function arguments in the symbol table.
|
||||
for (const auto &name_value :
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "toy/AST.h"
|
||||
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
|
@ -40,22 +41,22 @@ struct Indent {
|
|||
/// the way. The only data member is the current indentation level.
|
||||
class ASTDumper {
|
||||
public:
|
||||
void dump(ModuleAST *Node);
|
||||
void dump(ModuleAST *node);
|
||||
|
||||
private:
|
||||
void dump(VarType &type);
|
||||
void dump(const VarType &type);
|
||||
void dump(VarDeclExprAST *varDecl);
|
||||
void dump(ExprAST *expr);
|
||||
void dump(ExprASTList *exprList);
|
||||
void dump(NumberExprAST *num);
|
||||
void dump(LiteralExprAST *Node);
|
||||
void dump(VariableExprAST *Node);
|
||||
void dump(ReturnExprAST *Node);
|
||||
void dump(BinaryExprAST *Node);
|
||||
void dump(CallExprAST *Node);
|
||||
void dump(PrintExprAST *Node);
|
||||
void dump(PrototypeAST *Node);
|
||||
void dump(FunctionAST *Node);
|
||||
void dump(LiteralExprAST *node);
|
||||
void dump(VariableExprAST *node);
|
||||
void dump(ReturnExprAST *node);
|
||||
void dump(BinaryExprAST *node);
|
||||
void dump(CallExprAST *node);
|
||||
void dump(PrintExprAST *node);
|
||||
void dump(PrototypeAST *node);
|
||||
void dump(FunctionAST *node);
|
||||
|
||||
// Actually print spaces matching the current indentation level
|
||||
void indent() {
|
||||
|
@ -68,8 +69,8 @@ private:
|
|||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *Node) {
|
||||
const auto &loc = Node->loc();
|
||||
template <typename T> static std::string loc(T *node) {
|
||||
const auto &loc = node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
.str();
|
||||
|
@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) {
|
|||
/// [ [ 1, 2 ], [ 3, 4 ] ]
|
||||
/// We print out such array with the dimensions spelled out at every level:
|
||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||
void printLitHelper(ExprAST *lit_or_num) {
|
||||
void printLitHelper(ExprAST *litOrNum) {
|
||||
// Inside a literal expression we can have either a number or another literal
|
||||
if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
|
||||
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
||||
llvm::errs() << num->getValue();
|
||||
return;
|
||||
}
|
||||
auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
|
||||
auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
|
||||
|
||||
// Print the dimension for this literal first
|
||||
llvm::errs() << "<";
|
||||
{
|
||||
const char *sep = "";
|
||||
for (auto dim : literal->getDims()) {
|
||||
llvm::errs() << sep << dim;
|
||||
sep = ", ";
|
||||
}
|
||||
}
|
||||
mlir::interleaveComma(literal->getDims(), llvm::errs());
|
||||
llvm::errs() << ">";
|
||||
|
||||
// Now print the content, recursing on every element of the list
|
||||
llvm::errs() << "[ ";
|
||||
const char *sep = "";
|
||||
for (auto &elt : literal->getValues()) {
|
||||
llvm::errs() << sep;
|
||||
printLitHelper(elt.get());
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(literal->getValues(), llvm::errs(),
|
||||
[&](auto &elt) { printLitHelper(elt.get()); });
|
||||
llvm::errs() << "]";
|
||||
}
|
||||
|
||||
/// Print a literal, see the recursive helper above for the implementation.
|
||||
void ASTDumper::dump(LiteralExprAST *Node) {
|
||||
void ASTDumper::dump(LiteralExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Literal: ";
|
||||
printLitHelper(Node);
|
||||
llvm::errs() << " " << loc(Node) << "\n";
|
||||
printLitHelper(node);
|
||||
llvm::errs() << " " << loc(node) << "\n";
|
||||
}
|
||||
|
||||
/// Print a variable reference (just a name).
|
||||
void ASTDumper::dump(VariableExprAST *Node) {
|
||||
void ASTDumper::dump(VariableExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
|
||||
llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
|
||||
}
|
||||
|
||||
/// Return statement print the return and its (optional) argument.
|
||||
void ASTDumper::dump(ReturnExprAST *Node) {
|
||||
void ASTDumper::dump(ReturnExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Return\n";
|
||||
if (Node->getExpr().hasValue())
|
||||
return dump(*Node->getExpr());
|
||||
if (node->getExpr().hasValue())
|
||||
return dump(*node->getExpr());
|
||||
{
|
||||
INDENT();
|
||||
llvm::errs() << "(void)\n";
|
||||
|
@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) {
|
|||
}
|
||||
|
||||
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
|
||||
void ASTDumper::dump(BinaryExprAST *Node) {
|
||||
void ASTDumper::dump(BinaryExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
|
||||
dump(Node->getLHS());
|
||||
dump(Node->getRHS());
|
||||
llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
|
||||
dump(node->getLHS());
|
||||
dump(node->getRHS());
|
||||
}
|
||||
|
||||
/// Print a call expression, first the callee name and the list of args by
|
||||
/// recursing into each individual argument.
|
||||
void ASTDumper::dump(CallExprAST *Node) {
|
||||
void ASTDumper::dump(CallExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
|
||||
for (auto &arg : Node->getArgs())
|
||||
llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
|
||||
for (auto &arg : node->getArgs())
|
||||
dump(arg.get());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a builtin print call, first the builtin name and then the argument.
|
||||
void ASTDumper::dump(PrintExprAST *Node) {
|
||||
void ASTDumper::dump(PrintExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Print [ " << loc(Node) << "\n";
|
||||
dump(Node->getArg());
|
||||
llvm::errs() << "Print [ " << loc(node) << "\n";
|
||||
dump(node->getArg());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print type: only the shape is printed in between '<' and '>'
|
||||
void ASTDumper::dump(VarType &type) {
|
||||
void ASTDumper::dump(const VarType &type) {
|
||||
llvm::errs() << "<";
|
||||
const char *sep = "";
|
||||
for (auto shape : type.shape) {
|
||||
llvm::errs() << sep << shape;
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(type.shape, llvm::errs());
|
||||
llvm::errs() << ">";
|
||||
}
|
||||
|
||||
/// Print a function prototype, first the function name, and then the list of
|
||||
/// parameters names.
|
||||
void ASTDumper::dump(PrototypeAST *Node) {
|
||||
void ASTDumper::dump(PrototypeAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
|
||||
llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
|
||||
indent();
|
||||
llvm::errs() << "Params: [";
|
||||
const char *sep = "";
|
||||
for (auto &arg : Node->getArgs()) {
|
||||
llvm::errs() << sep << arg->getName();
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(node->getArgs(), llvm::errs(),
|
||||
[](auto &arg) { llvm::errs() << arg->getName(); });
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a function, first the prototype and then the body.
|
||||
void ASTDumper::dump(FunctionAST *Node) {
|
||||
void ASTDumper::dump(FunctionAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Function \n";
|
||||
dump(Node->getProto());
|
||||
dump(Node->getBody());
|
||||
dump(node->getProto());
|
||||
dump(node->getBody());
|
||||
}
|
||||
|
||||
/// Print a module, actually loop over the functions and print them in sequence.
|
||||
void ASTDumper::dump(ModuleAST *Node) {
|
||||
void ASTDumper::dump(ModuleAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Module:\n";
|
||||
for (auto &F : *Node)
|
||||
dump(&F);
|
||||
for (auto &f : *node)
|
||||
dump(&f);
|
||||
}
|
||||
|
||||
namespace toy {
|
||||
|
|
|
@ -64,20 +64,20 @@ static cl::opt<enum Action> emitAction(
|
|||
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
||||
cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")));
|
||||
|
||||
static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
|
||||
static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
|
||||
|
||||
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
|
||||
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
llvm::MemoryBuffer::getFileOrSTDIN(filename);
|
||||
if (std::error_code EC = FileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
||||
if (std::error_code ec = fileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
auto buffer = FileOrErr.get()->getBuffer();
|
||||
auto buffer = fileOrErr.get()->getBuffer();
|
||||
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
|
||||
Parser parser(lexer);
|
||||
return parser.ParseModule();
|
||||
return parser.parseModule();
|
||||
}
|
||||
|
||||
int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
|
||||
|
@ -119,7 +119,7 @@ int dumpMLIR() {
|
|||
if (int error = loadMLIR(sourceMgr, context, module))
|
||||
return error;
|
||||
|
||||
if (EnableOpt) {
|
||||
if (enableOpt) {
|
||||
mlir::PassManager pm(&context);
|
||||
// Apply any generic pass manager command line options and run the pipeline.
|
||||
applyPassManagerCLOptions(pm);
|
||||
|
|
|
@ -54,7 +54,6 @@ public:
|
|||
|
||||
ExprAST(ExprASTKind kind, Location location)
|
||||
: kind(kind), location(location) {}
|
||||
|
||||
virtual ~ExprAST() = default;
|
||||
|
||||
ExprASTKind getKind() const { return kind; }
|
||||
|
@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST {
|
|||
double Val;
|
||||
|
||||
public:
|
||||
NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
|
||||
NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
|
||||
|
||||
double getValue() { return Val; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
|
||||
};
|
||||
|
||||
/// Expression class for a literal value.
|
||||
|
@ -93,10 +92,11 @@ public:
|
|||
: ExprAST(Expr_Literal, loc), values(std::move(values)),
|
||||
dims(std::move(dims)) {}
|
||||
|
||||
std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
|
||||
std::vector<int64_t> &getDims() { return dims; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
|
||||
llvm::ArrayRef<int64_t> getDims() { return dims; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
|
||||
};
|
||||
|
||||
/// Expression class for referencing a variable, like "a".
|
||||
|
@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST {
|
|||
std::string name;
|
||||
|
||||
public:
|
||||
VariableExprAST(Location loc, const std::string &name)
|
||||
VariableExprAST(Location loc, llvm::StringRef name)
|
||||
: ExprAST(Expr_Var, loc), name(name) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
|
||||
};
|
||||
|
||||
/// Expression class for defining a variable.
|
||||
|
@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST {
|
|||
std::unique_ptr<ExprAST> initVal;
|
||||
|
||||
public:
|
||||
VarDeclExprAST(Location loc, const std::string &name, VarType type,
|
||||
VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
|
||||
std::unique_ptr<ExprAST> initVal)
|
||||
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
|
||||
initVal(std::move(initVal)) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
ExprAST *getInitVal() { return initVal.get(); }
|
||||
VarType &getType() { return type; }
|
||||
const VarType &getType() { return type; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
|
||||
};
|
||||
|
||||
/// Expression class for a return operator.
|
||||
|
@ -144,61 +144,61 @@ public:
|
|||
llvm::Optional<ExprAST *> getExpr() {
|
||||
if (expr.hasValue())
|
||||
return expr->get();
|
||||
return llvm::NoneType();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
|
||||
};
|
||||
|
||||
/// Expression class for a binary operator.
|
||||
class BinaryExprAST : public ExprAST {
|
||||
char Op;
|
||||
std::unique_ptr<ExprAST> LHS, RHS;
|
||||
char op;
|
||||
std::unique_ptr<ExprAST> lhs, rhs;
|
||||
|
||||
public:
|
||||
char getOp() { return Op; }
|
||||
ExprAST *getLHS() { return LHS.get(); }
|
||||
ExprAST *getRHS() { return RHS.get(); }
|
||||
char getOp() { return op; }
|
||||
ExprAST *getLHS() { return lhs.get(); }
|
||||
ExprAST *getRHS() { return rhs.get(); }
|
||||
|
||||
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
|
||||
std::unique_ptr<ExprAST> RHS)
|
||||
: ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
|
||||
RHS(std::move(RHS)) {}
|
||||
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
|
||||
std::unique_ptr<ExprAST> rhs)
|
||||
: ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
|
||||
rhs(std::move(rhs)) {}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
|
||||
};
|
||||
|
||||
/// Expression class for function calls.
|
||||
class CallExprAST : public ExprAST {
|
||||
std::string Callee;
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
std::string callee;
|
||||
std::vector<std::unique_ptr<ExprAST>> args;
|
||||
|
||||
public:
|
||||
CallExprAST(Location loc, const std::string &Callee,
|
||||
std::vector<std::unique_ptr<ExprAST>> Args)
|
||||
: ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
|
||||
CallExprAST(Location loc, const std::string &callee,
|
||||
std::vector<std::unique_ptr<ExprAST>> args)
|
||||
: ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
|
||||
|
||||
llvm::StringRef getCallee() { return Callee; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
|
||||
llvm::StringRef getCallee() { return callee; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
|
||||
};
|
||||
|
||||
/// Expression class for builtin print calls.
|
||||
class PrintExprAST : public ExprAST {
|
||||
std::unique_ptr<ExprAST> Arg;
|
||||
std::unique_ptr<ExprAST> arg;
|
||||
|
||||
public:
|
||||
PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
|
||||
: ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
|
||||
PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
|
||||
: ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
|
||||
|
||||
ExprAST *getArg() { return Arg.get(); }
|
||||
ExprAST *getArg() { return arg.get(); }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
|
||||
};
|
||||
|
||||
/// This class represents the "prototype" for a function, which captures its
|
||||
|
@ -215,23 +215,21 @@ public:
|
|||
: location(location), name(name), args(std::move(args)) {}
|
||||
|
||||
const Location &loc() { return location; }
|
||||
const std::string &getName() const { return name; }
|
||||
const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
|
||||
return args;
|
||||
}
|
||||
llvm::StringRef getName() const { return name; }
|
||||
llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
|
||||
};
|
||||
|
||||
/// This class represents a function definition itself.
|
||||
class FunctionAST {
|
||||
std::unique_ptr<PrototypeAST> Proto;
|
||||
std::unique_ptr<ExprASTList> Body;
|
||||
std::unique_ptr<PrototypeAST> proto;
|
||||
std::unique_ptr<ExprASTList> body;
|
||||
|
||||
public:
|
||||
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
|
||||
std::unique_ptr<ExprASTList> Body)
|
||||
: Proto(std::move(Proto)), Body(std::move(Body)) {}
|
||||
PrototypeAST *getProto() { return Proto.get(); }
|
||||
ExprASTList *getBody() { return Body.get(); }
|
||||
FunctionAST(std::unique_ptr<PrototypeAST> proto,
|
||||
std::unique_ptr<ExprASTList> body)
|
||||
: proto(std::move(proto)), body(std::move(body)) {}
|
||||
PrototypeAST *getProto() { return proto.get(); }
|
||||
ExprASTList *getBody() { return body.get(); }
|
||||
};
|
||||
|
||||
/// This class represents a list of functions to be processed together
|
||||
|
|
|
@ -89,13 +89,13 @@ public:
|
|||
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
|
||||
llvm::StringRef getId() {
|
||||
assert(curTok == tok_identifier);
|
||||
return IdentifierStr;
|
||||
return identifierStr;
|
||||
}
|
||||
|
||||
/// Return the current number (prereq: getCurToken() == tok_number)
|
||||
double getValue() {
|
||||
assert(curTok == tok_number);
|
||||
return NumVal;
|
||||
return numVal;
|
||||
}
|
||||
|
||||
/// Return the location for the beginning of the current token.
|
||||
|
@ -135,56 +135,58 @@ private:
|
|||
/// Return the next token from standard input.
|
||||
Token getTok() {
|
||||
// Skip any whitespace.
|
||||
while (isspace(LastChar))
|
||||
LastChar = Token(getNextChar());
|
||||
while (isspace(lastChar))
|
||||
lastChar = Token(getNextChar());
|
||||
|
||||
// Save the current location before reading the token characters.
|
||||
lastLocation.line = curLineNum;
|
||||
lastLocation.col = curCol;
|
||||
|
||||
if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||
IdentifierStr = (char)LastChar;
|
||||
while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
|
||||
IdentifierStr += (char)LastChar;
|
||||
// Identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||
if (isalpha(lastChar)) {
|
||||
identifierStr = (char)lastChar;
|
||||
while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
|
||||
identifierStr += (char)lastChar;
|
||||
|
||||
if (IdentifierStr == "return")
|
||||
if (identifierStr == "return")
|
||||
return tok_return;
|
||||
if (IdentifierStr == "def")
|
||||
if (identifierStr == "def")
|
||||
return tok_def;
|
||||
if (IdentifierStr == "var")
|
||||
if (identifierStr == "var")
|
||||
return tok_var;
|
||||
return tok_identifier;
|
||||
}
|
||||
|
||||
if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
|
||||
std::string NumStr;
|
||||
// Number: [0-9.]+
|
||||
if (isdigit(lastChar) || lastChar == '.') {
|
||||
std::string numStr;
|
||||
do {
|
||||
NumStr += LastChar;
|
||||
LastChar = Token(getNextChar());
|
||||
} while (isdigit(LastChar) || LastChar == '.');
|
||||
numStr += lastChar;
|
||||
lastChar = Token(getNextChar());
|
||||
} while (isdigit(lastChar) || lastChar == '.');
|
||||
|
||||
NumVal = strtod(NumStr.c_str(), nullptr);
|
||||
numVal = strtod(numStr.c_str(), nullptr);
|
||||
return tok_number;
|
||||
}
|
||||
|
||||
if (LastChar == '#') {
|
||||
if (lastChar == '#') {
|
||||
// Comment until end of line.
|
||||
do
|
||||
LastChar = Token(getNextChar());
|
||||
while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
|
||||
do {
|
||||
lastChar = Token(getNextChar());
|
||||
} while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
|
||||
|
||||
if (LastChar != EOF)
|
||||
if (lastChar != EOF)
|
||||
return getTok();
|
||||
}
|
||||
|
||||
// Check for end of file. Don't eat the EOF.
|
||||
if (LastChar == EOF)
|
||||
if (lastChar == EOF)
|
||||
return tok_eof;
|
||||
|
||||
// Otherwise, just return the character as its ascii value.
|
||||
Token ThisChar = Token(LastChar);
|
||||
LastChar = Token(getNextChar());
|
||||
return ThisChar;
|
||||
Token thisChar = Token(lastChar);
|
||||
lastChar = Token(getNextChar());
|
||||
return thisChar;
|
||||
}
|
||||
|
||||
/// The last token read from the input.
|
||||
|
@ -194,15 +196,15 @@ private:
|
|||
Location lastLocation;
|
||||
|
||||
/// If the current Token is an identifier, this string contains the value.
|
||||
std::string IdentifierStr;
|
||||
std::string identifierStr;
|
||||
|
||||
/// If the current Token is a number, this contains the value.
|
||||
double NumVal = 0;
|
||||
double numVal = 0;
|
||||
|
||||
/// The last value returned by getNextChar(). We need to keep it around as we
|
||||
/// always need to read ahead one character to decide when to end a token and
|
||||
/// we can't put it back in the stream after reading from it.
|
||||
Token LastChar = Token(' ');
|
||||
Token lastChar = Token(' ');
|
||||
|
||||
/// Keep track of the current line number in the input stream
|
||||
int curLineNum = 0;
|
||||
|
|
|
@ -48,13 +48,13 @@ public:
|
|||
Parser(Lexer &lexer) : lexer(lexer) {}
|
||||
|
||||
/// Parse a full Module. A module is a list of function definitions.
|
||||
std::unique_ptr<ModuleAST> ParseModule() {
|
||||
std::unique_ptr<ModuleAST> parseModule() {
|
||||
lexer.getNextToken(); // prime the lexer
|
||||
|
||||
// Parse functions one at a time and accumulate in this vector.
|
||||
std::vector<FunctionAST> functions;
|
||||
while (auto F = ParseDefinition()) {
|
||||
functions.push_back(std::move(*F));
|
||||
while (auto f = parseDefinition()) {
|
||||
functions.push_back(std::move(*f));
|
||||
if (lexer.getCurToken() == tok_eof)
|
||||
break;
|
||||
}
|
||||
|
@ -70,14 +70,14 @@ private:
|
|||
|
||||
/// Parse a return statement.
|
||||
/// return :== return ; | return expr ;
|
||||
std::unique_ptr<ReturnExprAST> ParseReturn() {
|
||||
std::unique_ptr<ReturnExprAST> parseReturn() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_return);
|
||||
|
||||
// return takes an optional argument
|
||||
llvm::Optional<std::unique_ptr<ExprAST>> expr;
|
||||
if (lexer.getCurToken() != ';') {
|
||||
expr = ParseExpression();
|
||||
expr = parseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -86,18 +86,18 @@ private:
|
|||
|
||||
/// Parse a literal number.
|
||||
/// numberexpr ::= number
|
||||
std::unique_ptr<ExprAST> ParseNumberExpr() {
|
||||
std::unique_ptr<ExprAST> parseNumberExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
auto Result =
|
||||
auto result =
|
||||
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
|
||||
lexer.consume(tok_number);
|
||||
return std::move(Result);
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
/// Parse a literal array expression.
|
||||
/// tensorLiteral ::= [ literalList ] | number
|
||||
/// literalList ::= tensorLiteral | tensorLiteral, literalList
|
||||
std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
|
||||
std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(Token('['));
|
||||
|
||||
|
@ -108,13 +108,13 @@ private:
|
|||
do {
|
||||
// We can have either another nested array or a number literal.
|
||||
if (lexer.getCurToken() == '[') {
|
||||
values.push_back(ParseTensorLiteralExpr());
|
||||
values.push_back(parseTensorLiteralExpr());
|
||||
if (!values.back())
|
||||
return nullptr; // parse error in the nested array.
|
||||
} else {
|
||||
if (lexer.getCurToken() != tok_number)
|
||||
return parseError<ExprAST>("<num> or [", "in literal expression");
|
||||
values.push_back(ParseNumberExpr());
|
||||
values.push_back(parseNumberExpr());
|
||||
}
|
||||
|
||||
// End of this list on ']'
|
||||
|
@ -130,8 +130,10 @@ private:
|
|||
if (values.empty())
|
||||
return parseError<ExprAST>("<something>", "to fill literal expression");
|
||||
lexer.getNextToken(); // eat ]
|
||||
|
||||
/// Fill in the dimensions now. First the current nesting level:
|
||||
dims.push_back(values.size());
|
||||
|
||||
/// If there is any nested array, process all of them and ensure that
|
||||
/// dimensions are uniform.
|
||||
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
|
||||
|
@ -143,7 +145,7 @@ private:
|
|||
"inside literal expression");
|
||||
|
||||
// Append the nested dimensions to the current level
|
||||
auto &firstDims = firstLiteral->getDims();
|
||||
auto firstDims = firstLiteral->getDims();
|
||||
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
|
||||
|
||||
// Sanity check that shape is uniform across all elements of the list.
|
||||
|
@ -162,22 +164,22 @@ private:
|
|||
}
|
||||
|
||||
/// parenexpr ::= '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseParenExpr() {
|
||||
std::unique_ptr<ExprAST> parseParenExpr() {
|
||||
lexer.getNextToken(); // eat (.
|
||||
auto V = ParseExpression();
|
||||
if (!V)
|
||||
auto v = parseExpression();
|
||||
if (!v)
|
||||
return nullptr;
|
||||
|
||||
if (lexer.getCurToken() != ')')
|
||||
return parseError<ExprAST>(")", "to close expression with parentheses");
|
||||
lexer.consume(Token(')'));
|
||||
return V;
|
||||
return v;
|
||||
}
|
||||
|
||||
/// identifierexpr
|
||||
/// ::= identifier
|
||||
/// ::= identifier '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseIdentifierExpr() {
|
||||
std::unique_ptr<ExprAST> parseIdentifierExpr() {
|
||||
std::string name = lexer.getId();
|
||||
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
@ -188,11 +190,11 @@ private:
|
|||
|
||||
// This is a function call.
|
||||
lexer.consume(Token('('));
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
std::vector<std::unique_ptr<ExprAST>> args;
|
||||
if (lexer.getCurToken() != ')') {
|
||||
while (true) {
|
||||
if (auto Arg = ParseExpression())
|
||||
Args.push_back(std::move(Arg));
|
||||
if (auto arg = parseExpression())
|
||||
args.push_back(std::move(arg));
|
||||
else
|
||||
return nullptr;
|
||||
|
||||
|
@ -208,14 +210,14 @@ private:
|
|||
|
||||
// It can be a builtin call to print
|
||||
if (name == "print") {
|
||||
if (Args.size() != 1)
|
||||
if (args.size() != 1)
|
||||
return parseError<ExprAST>("<single arg>", "as argument to print()");
|
||||
|
||||
return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0]));
|
||||
return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
|
||||
}
|
||||
|
||||
// Call to a user-defined function
|
||||
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args));
|
||||
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
|
||||
}
|
||||
|
||||
/// primary
|
||||
|
@ -223,20 +225,20 @@ private:
|
|||
/// ::= numberexpr
|
||||
/// ::= parenexpr
|
||||
/// ::= tensorliteral
|
||||
std::unique_ptr<ExprAST> ParsePrimary() {
|
||||
std::unique_ptr<ExprAST> parsePrimary() {
|
||||
switch (lexer.getCurToken()) {
|
||||
default:
|
||||
llvm::errs() << "unknown token '" << lexer.getCurToken()
|
||||
<< "' when expecting an expression\n";
|
||||
return nullptr;
|
||||
case tok_identifier:
|
||||
return ParseIdentifierExpr();
|
||||
return parseIdentifierExpr();
|
||||
case tok_number:
|
||||
return ParseNumberExpr();
|
||||
return parseNumberExpr();
|
||||
case '(':
|
||||
return ParseParenExpr();
|
||||
return parseParenExpr();
|
||||
case '[':
|
||||
return ParseTensorLiteralExpr();
|
||||
return parseTensorLiteralExpr();
|
||||
case ';':
|
||||
return nullptr;
|
||||
case '}':
|
||||
|
@ -248,54 +250,54 @@ private:
|
|||
/// argument indicates the precedence of the current binary operator.
|
||||
///
|
||||
/// binoprhs ::= ('+' primary)*
|
||||
std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
|
||||
std::unique_ptr<ExprAST> LHS) {
|
||||
std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
|
||||
std::unique_ptr<ExprAST> lhs) {
|
||||
// If this is a binop, find its precedence.
|
||||
while (true) {
|
||||
int TokPrec = GetTokPrecedence();
|
||||
int tokPrec = getTokPrecedence();
|
||||
|
||||
// If this is a binop that binds at least as tightly as the current binop,
|
||||
// consume it, otherwise we are done.
|
||||
if (TokPrec < ExprPrec)
|
||||
return LHS;
|
||||
if (tokPrec < exprPrec)
|
||||
return lhs;
|
||||
|
||||
// Okay, we know this is a binop.
|
||||
int BinOp = lexer.getCurToken();
|
||||
lexer.consume(Token(BinOp));
|
||||
int binOp = lexer.getCurToken();
|
||||
lexer.consume(Token(binOp));
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
||||
// Parse the primary expression after the binary operator.
|
||||
auto RHS = ParsePrimary();
|
||||
if (!RHS)
|
||||
auto rhs = parsePrimary();
|
||||
if (!rhs)
|
||||
return parseError<ExprAST>("expression", "to complete binary operator");
|
||||
|
||||
// If BinOp binds less tightly with RHS than the operator after RHS, let
|
||||
// the pending operator take RHS as its LHS.
|
||||
int NextPrec = GetTokPrecedence();
|
||||
if (TokPrec < NextPrec) {
|
||||
RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
|
||||
if (!RHS)
|
||||
// If BinOp binds less tightly with rhs than the operator after rhs, let
|
||||
// the pending operator take rhs as its lhs.
|
||||
int nextPrec = getTokPrecedence();
|
||||
if (tokPrec < nextPrec) {
|
||||
rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
|
||||
if (!rhs)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Merge LHS/RHS.
|
||||
LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp,
|
||||
std::move(LHS), std::move(RHS));
|
||||
// Merge lhs/RHS.
|
||||
lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
|
||||
std::move(lhs), std::move(rhs));
|
||||
}
|
||||
}
|
||||
|
||||
/// expression::= primary binoprhs
|
||||
std::unique_ptr<ExprAST> ParseExpression() {
|
||||
auto LHS = ParsePrimary();
|
||||
if (!LHS)
|
||||
/// expression::= primary binop rhs
|
||||
std::unique_ptr<ExprAST> parseExpression() {
|
||||
auto lhs = parsePrimary();
|
||||
if (!lhs)
|
||||
return nullptr;
|
||||
|
||||
return ParseBinOpRHS(0, std::move(LHS));
|
||||
return parseBinOpRHS(0, std::move(lhs));
|
||||
}
|
||||
|
||||
/// type ::= < shape_list >
|
||||
/// shape_list ::= num | num , shape_list
|
||||
std::unique_ptr<VarType> ParseType() {
|
||||
std::unique_ptr<VarType> parseType() {
|
||||
if (lexer.getCurToken() != '<')
|
||||
return parseError<VarType>("<", "to begin type");
|
||||
lexer.getNextToken(); // eat <
|
||||
|
@ -319,7 +321,7 @@ private:
|
|||
/// and identifier and an optional type (shape specification) before the
|
||||
/// initializer.
|
||||
/// decl ::= var identifier [ type ] = expr
|
||||
std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
|
||||
std::unique_ptr<VarDeclExprAST> parseDeclaration() {
|
||||
if (lexer.getCurToken() != tok_var)
|
||||
return parseError<VarDeclExprAST>("var", "to begin declaration");
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
@ -333,7 +335,7 @@ private:
|
|||
|
||||
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
|
||||
if (lexer.getCurToken() == '<') {
|
||||
type = ParseType();
|
||||
type = parseType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -341,7 +343,7 @@ private:
|
|||
if (!type)
|
||||
type = std::make_unique<VarType>();
|
||||
lexer.consume(Token('='));
|
||||
auto expr = ParseExpression();
|
||||
auto expr = parseExpression();
|
||||
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
|
||||
std::move(*type), std::move(expr));
|
||||
}
|
||||
|
@ -352,7 +354,7 @@ private:
|
|||
/// block ::= { expression_list }
|
||||
/// expression_list ::= block_expr ; expression_list
|
||||
/// block_expr ::= decl | "return" | expr
|
||||
std::unique_ptr<ExprASTList> ParseBlock() {
|
||||
std::unique_ptr<ExprASTList> parseBlock() {
|
||||
if (lexer.getCurToken() != '{')
|
||||
return parseError<ExprASTList>("{", "to begin block");
|
||||
lexer.consume(Token('{'));
|
||||
|
@ -366,19 +368,19 @@ private:
|
|||
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
|
||||
if (lexer.getCurToken() == tok_var) {
|
||||
// Variable declaration
|
||||
auto varDecl = ParseDeclaration();
|
||||
auto varDecl = parseDeclaration();
|
||||
if (!varDecl)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(varDecl));
|
||||
} else if (lexer.getCurToken() == tok_return) {
|
||||
// Return statement
|
||||
auto ret = ParseReturn();
|
||||
auto ret = parseReturn();
|
||||
if (!ret)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(ret));
|
||||
} else {
|
||||
// General expression
|
||||
auto expr = ParseExpression();
|
||||
auto expr = parseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(expr));
|
||||
|
@ -401,13 +403,13 @@ private:
|
|||
|
||||
/// prototype ::= def id '(' decl_list ')'
|
||||
/// decl_list ::= identifier | identifier, decl_list
|
||||
std::unique_ptr<PrototypeAST> ParsePrototype() {
|
||||
std::unique_ptr<PrototypeAST> parsePrototype() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_def);
|
||||
if (lexer.getCurToken() != tok_identifier)
|
||||
return parseError<PrototypeAST>("function name", "in prototype");
|
||||
|
||||
std::string FnName = lexer.getId();
|
||||
std::string fnName = lexer.getId();
|
||||
lexer.consume(tok_identifier);
|
||||
|
||||
if (lexer.getCurToken() != '(')
|
||||
|
@ -435,7 +437,7 @@ private:
|
|||
|
||||
// success.
|
||||
lexer.consume(Token(')'));
|
||||
return std::make_unique<PrototypeAST>(std::move(loc), FnName,
|
||||
return std::make_unique<PrototypeAST>(std::move(loc), fnName,
|
||||
std::move(args));
|
||||
}
|
||||
|
||||
|
@ -443,18 +445,18 @@ private:
|
|||
/// `def` keyword, followed by a block containing a list of expressions.
|
||||
///
|
||||
/// definition ::= prototype block
|
||||
std::unique_ptr<FunctionAST> ParseDefinition() {
|
||||
auto Proto = ParsePrototype();
|
||||
if (!Proto)
|
||||
std::unique_ptr<FunctionAST> parseDefinition() {
|
||||
auto proto = parsePrototype();
|
||||
if (!proto)
|
||||
return nullptr;
|
||||
|
||||
if (auto block = ParseBlock())
|
||||
return std::make_unique<FunctionAST>(std::move(Proto), std::move(block));
|
||||
if (auto block = parseBlock())
|
||||
return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Get the precedence of the pending binary operator token.
|
||||
int GetTokPrecedence() {
|
||||
int getTokPrecedence() {
|
||||
if (!isascii(lexer.getCurToken()))
|
||||
return -1;
|
||||
|
||||
|
|
|
@ -143,7 +143,7 @@ private:
|
|||
// In MLIR the entry block of the function is special: it must have the same
|
||||
// argument list as the function itself.
|
||||
auto &entryBlock = *function.addEntryBlock();
|
||||
auto &protoArgs = funcAST.getProto()->getArgs();
|
||||
auto protoArgs = funcAST.getProto()->getArgs();
|
||||
|
||||
// Declare all the function arguments in the symbol table.
|
||||
for (const auto &name_value :
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "toy/AST.h"
|
||||
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
|
@ -40,22 +41,22 @@ struct Indent {
|
|||
/// the way. The only data member is the current indentation level.
|
||||
class ASTDumper {
|
||||
public:
|
||||
void dump(ModuleAST *Node);
|
||||
void dump(ModuleAST *node);
|
||||
|
||||
private:
|
||||
void dump(VarType &type);
|
||||
void dump(const VarType &type);
|
||||
void dump(VarDeclExprAST *varDecl);
|
||||
void dump(ExprAST *expr);
|
||||
void dump(ExprASTList *exprList);
|
||||
void dump(NumberExprAST *num);
|
||||
void dump(LiteralExprAST *Node);
|
||||
void dump(VariableExprAST *Node);
|
||||
void dump(ReturnExprAST *Node);
|
||||
void dump(BinaryExprAST *Node);
|
||||
void dump(CallExprAST *Node);
|
||||
void dump(PrintExprAST *Node);
|
||||
void dump(PrototypeAST *Node);
|
||||
void dump(FunctionAST *Node);
|
||||
void dump(LiteralExprAST *node);
|
||||
void dump(VariableExprAST *node);
|
||||
void dump(ReturnExprAST *node);
|
||||
void dump(BinaryExprAST *node);
|
||||
void dump(CallExprAST *node);
|
||||
void dump(PrintExprAST *node);
|
||||
void dump(PrototypeAST *node);
|
||||
void dump(FunctionAST *node);
|
||||
|
||||
// Actually print spaces matching the current indentation level
|
||||
void indent() {
|
||||
|
@ -68,8 +69,8 @@ private:
|
|||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *Node) {
|
||||
const auto &loc = Node->loc();
|
||||
template <typename T> static std::string loc(T *node) {
|
||||
const auto &loc = node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
.str();
|
||||
|
@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) {
|
|||
/// [ [ 1, 2 ], [ 3, 4 ] ]
|
||||
/// We print out such array with the dimensions spelled out at every level:
|
||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||
void printLitHelper(ExprAST *lit_or_num) {
|
||||
void printLitHelper(ExprAST *litOrNum) {
|
||||
// Inside a literal expression we can have either a number or another literal
|
||||
if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
|
||||
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
||||
llvm::errs() << num->getValue();
|
||||
return;
|
||||
}
|
||||
auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
|
||||
auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
|
||||
|
||||
// Print the dimension for this literal first
|
||||
llvm::errs() << "<";
|
||||
{
|
||||
const char *sep = "";
|
||||
for (auto dim : literal->getDims()) {
|
||||
llvm::errs() << sep << dim;
|
||||
sep = ", ";
|
||||
}
|
||||
}
|
||||
mlir::interleaveComma(literal->getDims(), llvm::errs());
|
||||
llvm::errs() << ">";
|
||||
|
||||
// Now print the content, recursing on every element of the list
|
||||
llvm::errs() << "[ ";
|
||||
const char *sep = "";
|
||||
for (auto &elt : literal->getValues()) {
|
||||
llvm::errs() << sep;
|
||||
printLitHelper(elt.get());
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(literal->getValues(), llvm::errs(),
|
||||
[&](auto &elt) { printLitHelper(elt.get()); });
|
||||
llvm::errs() << "]";
|
||||
}
|
||||
|
||||
/// Print a literal, see the recursive helper above for the implementation.
|
||||
void ASTDumper::dump(LiteralExprAST *Node) {
|
||||
void ASTDumper::dump(LiteralExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Literal: ";
|
||||
printLitHelper(Node);
|
||||
llvm::errs() << " " << loc(Node) << "\n";
|
||||
printLitHelper(node);
|
||||
llvm::errs() << " " << loc(node) << "\n";
|
||||
}
|
||||
|
||||
/// Print a variable reference (just a name).
|
||||
void ASTDumper::dump(VariableExprAST *Node) {
|
||||
void ASTDumper::dump(VariableExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
|
||||
llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
|
||||
}
|
||||
|
||||
/// Return statement print the return and its (optional) argument.
|
||||
void ASTDumper::dump(ReturnExprAST *Node) {
|
||||
void ASTDumper::dump(ReturnExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Return\n";
|
||||
if (Node->getExpr().hasValue())
|
||||
return dump(*Node->getExpr());
|
||||
if (node->getExpr().hasValue())
|
||||
return dump(*node->getExpr());
|
||||
{
|
||||
INDENT();
|
||||
llvm::errs() << "(void)\n";
|
||||
|
@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) {
|
|||
}
|
||||
|
||||
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
|
||||
void ASTDumper::dump(BinaryExprAST *Node) {
|
||||
void ASTDumper::dump(BinaryExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
|
||||
dump(Node->getLHS());
|
||||
dump(Node->getRHS());
|
||||
llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
|
||||
dump(node->getLHS());
|
||||
dump(node->getRHS());
|
||||
}
|
||||
|
||||
/// Print a call expression, first the callee name and the list of args by
|
||||
/// recursing into each individual argument.
|
||||
void ASTDumper::dump(CallExprAST *Node) {
|
||||
void ASTDumper::dump(CallExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
|
||||
for (auto &arg : Node->getArgs())
|
||||
llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
|
||||
for (auto &arg : node->getArgs())
|
||||
dump(arg.get());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a builtin print call, first the builtin name and then the argument.
|
||||
void ASTDumper::dump(PrintExprAST *Node) {
|
||||
void ASTDumper::dump(PrintExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Print [ " << loc(Node) << "\n";
|
||||
dump(Node->getArg());
|
||||
llvm::errs() << "Print [ " << loc(node) << "\n";
|
||||
dump(node->getArg());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print type: only the shape is printed in between '<' and '>'
|
||||
void ASTDumper::dump(VarType &type) {
|
||||
void ASTDumper::dump(const VarType &type) {
|
||||
llvm::errs() << "<";
|
||||
const char *sep = "";
|
||||
for (auto shape : type.shape) {
|
||||
llvm::errs() << sep << shape;
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(type.shape, llvm::errs());
|
||||
llvm::errs() << ">";
|
||||
}
|
||||
|
||||
/// Print a function prototype, first the function name, and then the list of
|
||||
/// parameters names.
|
||||
void ASTDumper::dump(PrototypeAST *Node) {
|
||||
void ASTDumper::dump(PrototypeAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
|
||||
llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
|
||||
indent();
|
||||
llvm::errs() << "Params: [";
|
||||
const char *sep = "";
|
||||
for (auto &arg : Node->getArgs()) {
|
||||
llvm::errs() << sep << arg->getName();
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(node->getArgs(), llvm::errs(),
|
||||
[](auto &arg) { llvm::errs() << arg->getName(); });
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a function, first the prototype and then the body.
|
||||
void ASTDumper::dump(FunctionAST *Node) {
|
||||
void ASTDumper::dump(FunctionAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Function \n";
|
||||
dump(Node->getProto());
|
||||
dump(Node->getBody());
|
||||
dump(node->getProto());
|
||||
dump(node->getBody());
|
||||
}
|
||||
|
||||
/// Print a module, actually loop over the functions and print them in sequence.
|
||||
void ASTDumper::dump(ModuleAST *Node) {
|
||||
void ASTDumper::dump(ModuleAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Module:\n";
|
||||
for (auto &F : *Node)
|
||||
dump(&F);
|
||||
for (auto &f : *node)
|
||||
dump(&f);
|
||||
}
|
||||
|
||||
namespace toy {
|
||||
|
|
|
@ -66,20 +66,20 @@ static cl::opt<enum Action> emitAction(
|
|||
cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine",
|
||||
"output the MLIR dump after affine lowering")));
|
||||
|
||||
static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
|
||||
static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
|
||||
|
||||
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
|
||||
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
llvm::MemoryBuffer::getFileOrSTDIN(filename);
|
||||
if (std::error_code EC = FileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
||||
if (std::error_code ec = fileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
auto buffer = FileOrErr.get()->getBuffer();
|
||||
auto buffer = fileOrErr.get()->getBuffer();
|
||||
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
|
||||
Parser parser(lexer);
|
||||
return parser.ParseModule();
|
||||
return parser.parseModule();
|
||||
}
|
||||
|
||||
int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
|
||||
|
@ -128,7 +128,7 @@ int dumpMLIR() {
|
|||
// Check to see what granularity of MLIR we are compiling to.
|
||||
bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine;
|
||||
|
||||
if (EnableOpt || isLoweringToAffine) {
|
||||
if (enableOpt || isLoweringToAffine) {
|
||||
// Inline all functions into main and then delete them.
|
||||
pm.addPass(mlir::createInlinerPass());
|
||||
pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
|
||||
|
@ -150,7 +150,7 @@ int dumpMLIR() {
|
|||
optPM.addPass(mlir::createCSEPass());
|
||||
|
||||
// Add optimizations if enabled.
|
||||
if (EnableOpt) {
|
||||
if (enableOpt) {
|
||||
optPM.addPass(mlir::createLoopFusionPass());
|
||||
optPM.addPass(mlir::createMemRefDataFlowOptPass());
|
||||
}
|
||||
|
|
|
@ -54,7 +54,6 @@ public:
|
|||
|
||||
ExprAST(ExprASTKind kind, Location location)
|
||||
: kind(kind), location(location) {}
|
||||
|
||||
virtual ~ExprAST() = default;
|
||||
|
||||
ExprASTKind getKind() const { return kind; }
|
||||
|
@ -74,12 +73,12 @@ class NumberExprAST : public ExprAST {
|
|||
double Val;
|
||||
|
||||
public:
|
||||
NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
|
||||
NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
|
||||
|
||||
double getValue() { return Val; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
|
||||
};
|
||||
|
||||
/// Expression class for a literal value.
|
||||
|
@ -93,10 +92,11 @@ public:
|
|||
: ExprAST(Expr_Literal, loc), values(std::move(values)),
|
||||
dims(std::move(dims)) {}
|
||||
|
||||
std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
|
||||
std::vector<int64_t> &getDims() { return dims; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
|
||||
llvm::ArrayRef<int64_t> getDims() { return dims; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
|
||||
};
|
||||
|
||||
/// Expression class for referencing a variable, like "a".
|
||||
|
@ -104,13 +104,13 @@ class VariableExprAST : public ExprAST {
|
|||
std::string name;
|
||||
|
||||
public:
|
||||
VariableExprAST(Location loc, const std::string &name)
|
||||
VariableExprAST(Location loc, llvm::StringRef name)
|
||||
: ExprAST(Expr_Var, loc), name(name) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
|
||||
};
|
||||
|
||||
/// Expression class for defining a variable.
|
||||
|
@ -120,17 +120,17 @@ class VarDeclExprAST : public ExprAST {
|
|||
std::unique_ptr<ExprAST> initVal;
|
||||
|
||||
public:
|
||||
VarDeclExprAST(Location loc, const std::string &name, VarType type,
|
||||
VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
|
||||
std::unique_ptr<ExprAST> initVal)
|
||||
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
|
||||
initVal(std::move(initVal)) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
ExprAST *getInitVal() { return initVal.get(); }
|
||||
VarType &getType() { return type; }
|
||||
const VarType &getType() { return type; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
|
||||
};
|
||||
|
||||
/// Expression class for a return operator.
|
||||
|
@ -144,61 +144,61 @@ public:
|
|||
llvm::Optional<ExprAST *> getExpr() {
|
||||
if (expr.hasValue())
|
||||
return expr->get();
|
||||
return llvm::NoneType();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
|
||||
};
|
||||
|
||||
/// Expression class for a binary operator.
|
||||
class BinaryExprAST : public ExprAST {
|
||||
char Op;
|
||||
std::unique_ptr<ExprAST> LHS, RHS;
|
||||
char op;
|
||||
std::unique_ptr<ExprAST> lhs, rhs;
|
||||
|
||||
public:
|
||||
char getOp() { return Op; }
|
||||
ExprAST *getLHS() { return LHS.get(); }
|
||||
ExprAST *getRHS() { return RHS.get(); }
|
||||
char getOp() { return op; }
|
||||
ExprAST *getLHS() { return lhs.get(); }
|
||||
ExprAST *getRHS() { return rhs.get(); }
|
||||
|
||||
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
|
||||
std::unique_ptr<ExprAST> RHS)
|
||||
: ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
|
||||
RHS(std::move(RHS)) {}
|
||||
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
|
||||
std::unique_ptr<ExprAST> rhs)
|
||||
: ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
|
||||
rhs(std::move(rhs)) {}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
|
||||
};
|
||||
|
||||
/// Expression class for function calls.
|
||||
class CallExprAST : public ExprAST {
|
||||
std::string Callee;
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
std::string callee;
|
||||
std::vector<std::unique_ptr<ExprAST>> args;
|
||||
|
||||
public:
|
||||
CallExprAST(Location loc, const std::string &Callee,
|
||||
std::vector<std::unique_ptr<ExprAST>> Args)
|
||||
: ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
|
||||
CallExprAST(Location loc, const std::string &callee,
|
||||
std::vector<std::unique_ptr<ExprAST>> args)
|
||||
: ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
|
||||
|
||||
llvm::StringRef getCallee() { return Callee; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
|
||||
llvm::StringRef getCallee() { return callee; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
|
||||
};
|
||||
|
||||
/// Expression class for builtin print calls.
|
||||
class PrintExprAST : public ExprAST {
|
||||
std::unique_ptr<ExprAST> Arg;
|
||||
std::unique_ptr<ExprAST> arg;
|
||||
|
||||
public:
|
||||
PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
|
||||
: ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
|
||||
PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
|
||||
: ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
|
||||
|
||||
ExprAST *getArg() { return Arg.get(); }
|
||||
ExprAST *getArg() { return arg.get(); }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
|
||||
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
|
||||
};
|
||||
|
||||
/// This class represents the "prototype" for a function, which captures its
|
||||
|
@ -215,23 +215,21 @@ public:
|
|||
: location(location), name(name), args(std::move(args)) {}
|
||||
|
||||
const Location &loc() { return location; }
|
||||
const std::string &getName() const { return name; }
|
||||
const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
|
||||
return args;
|
||||
}
|
||||
llvm::StringRef getName() const { return name; }
|
||||
llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
|
||||
};
|
||||
|
||||
/// This class represents a function definition itself.
|
||||
class FunctionAST {
|
||||
std::unique_ptr<PrototypeAST> Proto;
|
||||
std::unique_ptr<ExprASTList> Body;
|
||||
std::unique_ptr<PrototypeAST> proto;
|
||||
std::unique_ptr<ExprASTList> body;
|
||||
|
||||
public:
|
||||
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
|
||||
std::unique_ptr<ExprASTList> Body)
|
||||
: Proto(std::move(Proto)), Body(std::move(Body)) {}
|
||||
PrototypeAST *getProto() { return Proto.get(); }
|
||||
ExprASTList *getBody() { return Body.get(); }
|
||||
FunctionAST(std::unique_ptr<PrototypeAST> proto,
|
||||
std::unique_ptr<ExprASTList> body)
|
||||
: proto(std::move(proto)), body(std::move(body)) {}
|
||||
PrototypeAST *getProto() { return proto.get(); }
|
||||
ExprASTList *getBody() { return body.get(); }
|
||||
};
|
||||
|
||||
/// This class represents a list of functions to be processed together
|
||||
|
|
|
@ -89,13 +89,13 @@ public:
|
|||
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
|
||||
llvm::StringRef getId() {
|
||||
assert(curTok == tok_identifier);
|
||||
return IdentifierStr;
|
||||
return identifierStr;
|
||||
}
|
||||
|
||||
/// Return the current number (prereq: getCurToken() == tok_number)
|
||||
double getValue() {
|
||||
assert(curTok == tok_number);
|
||||
return NumVal;
|
||||
return numVal;
|
||||
}
|
||||
|
||||
/// Return the location for the beginning of the current token.
|
||||
|
@ -135,56 +135,58 @@ private:
|
|||
/// Return the next token from standard input.
|
||||
Token getTok() {
|
||||
// Skip any whitespace.
|
||||
while (isspace(LastChar))
|
||||
LastChar = Token(getNextChar());
|
||||
while (isspace(lastChar))
|
||||
lastChar = Token(getNextChar());
|
||||
|
||||
// Save the current location before reading the token characters.
|
||||
lastLocation.line = curLineNum;
|
||||
lastLocation.col = curCol;
|
||||
|
||||
if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||
IdentifierStr = (char)LastChar;
|
||||
while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
|
||||
IdentifierStr += (char)LastChar;
|
||||
// Identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||
if (isalpha(lastChar)) {
|
||||
identifierStr = (char)lastChar;
|
||||
while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_')
|
||||
identifierStr += (char)lastChar;
|
||||
|
||||
if (IdentifierStr == "return")
|
||||
if (identifierStr == "return")
|
||||
return tok_return;
|
||||
if (IdentifierStr == "def")
|
||||
if (identifierStr == "def")
|
||||
return tok_def;
|
||||
if (IdentifierStr == "var")
|
||||
if (identifierStr == "var")
|
||||
return tok_var;
|
||||
return tok_identifier;
|
||||
}
|
||||
|
||||
if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
|
||||
std::string NumStr;
|
||||
// Number: [0-9.]+
|
||||
if (isdigit(lastChar) || lastChar == '.') {
|
||||
std::string numStr;
|
||||
do {
|
||||
NumStr += LastChar;
|
||||
LastChar = Token(getNextChar());
|
||||
} while (isdigit(LastChar) || LastChar == '.');
|
||||
numStr += lastChar;
|
||||
lastChar = Token(getNextChar());
|
||||
} while (isdigit(lastChar) || lastChar == '.');
|
||||
|
||||
NumVal = strtod(NumStr.c_str(), nullptr);
|
||||
numVal = strtod(numStr.c_str(), nullptr);
|
||||
return tok_number;
|
||||
}
|
||||
|
||||
if (LastChar == '#') {
|
||||
if (lastChar == '#') {
|
||||
// Comment until end of line.
|
||||
do
|
||||
LastChar = Token(getNextChar());
|
||||
while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
|
||||
do {
|
||||
lastChar = Token(getNextChar());
|
||||
} while (lastChar != EOF && lastChar != '\n' && lastChar != '\r');
|
||||
|
||||
if (LastChar != EOF)
|
||||
if (lastChar != EOF)
|
||||
return getTok();
|
||||
}
|
||||
|
||||
// Check for end of file. Don't eat the EOF.
|
||||
if (LastChar == EOF)
|
||||
if (lastChar == EOF)
|
||||
return tok_eof;
|
||||
|
||||
// Otherwise, just return the character as its ascii value.
|
||||
Token ThisChar = Token(LastChar);
|
||||
LastChar = Token(getNextChar());
|
||||
return ThisChar;
|
||||
Token thisChar = Token(lastChar);
|
||||
lastChar = Token(getNextChar());
|
||||
return thisChar;
|
||||
}
|
||||
|
||||
/// The last token read from the input.
|
||||
|
@ -194,15 +196,15 @@ private:
|
|||
Location lastLocation;
|
||||
|
||||
/// If the current Token is an identifier, this string contains the value.
|
||||
std::string IdentifierStr;
|
||||
std::string identifierStr;
|
||||
|
||||
/// If the current Token is a number, this contains the value.
|
||||
double NumVal = 0;
|
||||
double numVal = 0;
|
||||
|
||||
/// The last value returned by getNextChar(). We need to keep it around as we
|
||||
/// always need to read ahead one character to decide when to end a token and
|
||||
/// we can't put it back in the stream after reading from it.
|
||||
Token LastChar = Token(' ');
|
||||
Token lastChar = Token(' ');
|
||||
|
||||
/// Keep track of the current line number in the input stream
|
||||
int curLineNum = 0;
|
||||
|
|
|
@ -48,13 +48,13 @@ public:
|
|||
Parser(Lexer &lexer) : lexer(lexer) {}
|
||||
|
||||
/// Parse a full Module. A module is a list of function definitions.
|
||||
std::unique_ptr<ModuleAST> ParseModule() {
|
||||
std::unique_ptr<ModuleAST> parseModule() {
|
||||
lexer.getNextToken(); // prime the lexer
|
||||
|
||||
// Parse functions one at a time and accumulate in this vector.
|
||||
std::vector<FunctionAST> functions;
|
||||
while (auto F = ParseDefinition()) {
|
||||
functions.push_back(std::move(*F));
|
||||
while (auto f = parseDefinition()) {
|
||||
functions.push_back(std::move(*f));
|
||||
if (lexer.getCurToken() == tok_eof)
|
||||
break;
|
||||
}
|
||||
|
@ -70,14 +70,14 @@ private:
|
|||
|
||||
/// Parse a return statement.
|
||||
/// return :== return ; | return expr ;
|
||||
std::unique_ptr<ReturnExprAST> ParseReturn() {
|
||||
std::unique_ptr<ReturnExprAST> parseReturn() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_return);
|
||||
|
||||
// return takes an optional argument
|
||||
llvm::Optional<std::unique_ptr<ExprAST>> expr;
|
||||
if (lexer.getCurToken() != ';') {
|
||||
expr = ParseExpression();
|
||||
expr = parseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -86,18 +86,18 @@ private:
|
|||
|
||||
/// Parse a literal number.
|
||||
/// numberexpr ::= number
|
||||
std::unique_ptr<ExprAST> ParseNumberExpr() {
|
||||
std::unique_ptr<ExprAST> parseNumberExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
auto Result =
|
||||
auto result =
|
||||
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
|
||||
lexer.consume(tok_number);
|
||||
return std::move(Result);
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
/// Parse a literal array expression.
|
||||
/// tensorLiteral ::= [ literalList ] | number
|
||||
/// literalList ::= tensorLiteral | tensorLiteral, literalList
|
||||
std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
|
||||
std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(Token('['));
|
||||
|
||||
|
@ -108,13 +108,13 @@ private:
|
|||
do {
|
||||
// We can have either another nested array or a number literal.
|
||||
if (lexer.getCurToken() == '[') {
|
||||
values.push_back(ParseTensorLiteralExpr());
|
||||
values.push_back(parseTensorLiteralExpr());
|
||||
if (!values.back())
|
||||
return nullptr; // parse error in the nested array.
|
||||
} else {
|
||||
if (lexer.getCurToken() != tok_number)
|
||||
return parseError<ExprAST>("<num> or [", "in literal expression");
|
||||
values.push_back(ParseNumberExpr());
|
||||
values.push_back(parseNumberExpr());
|
||||
}
|
||||
|
||||
// End of this list on ']'
|
||||
|
@ -130,8 +130,10 @@ private:
|
|||
if (values.empty())
|
||||
return parseError<ExprAST>("<something>", "to fill literal expression");
|
||||
lexer.getNextToken(); // eat ]
|
||||
|
||||
/// Fill in the dimensions now. First the current nesting level:
|
||||
dims.push_back(values.size());
|
||||
|
||||
/// If there is any nested array, process all of them and ensure that
|
||||
/// dimensions are uniform.
|
||||
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
|
||||
|
@ -143,7 +145,7 @@ private:
|
|||
"inside literal expression");
|
||||
|
||||
// Append the nested dimensions to the current level
|
||||
auto &firstDims = firstLiteral->getDims();
|
||||
auto firstDims = firstLiteral->getDims();
|
||||
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
|
||||
|
||||
// Sanity check that shape is uniform across all elements of the list.
|
||||
|
@ -162,22 +164,22 @@ private:
|
|||
}
|
||||
|
||||
/// parenexpr ::= '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseParenExpr() {
|
||||
std::unique_ptr<ExprAST> parseParenExpr() {
|
||||
lexer.getNextToken(); // eat (.
|
||||
auto V = ParseExpression();
|
||||
if (!V)
|
||||
auto v = parseExpression();
|
||||
if (!v)
|
||||
return nullptr;
|
||||
|
||||
if (lexer.getCurToken() != ')')
|
||||
return parseError<ExprAST>(")", "to close expression with parentheses");
|
||||
lexer.consume(Token(')'));
|
||||
return V;
|
||||
return v;
|
||||
}
|
||||
|
||||
/// identifierexpr
|
||||
/// ::= identifier
|
||||
/// ::= identifier '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseIdentifierExpr() {
|
||||
std::unique_ptr<ExprAST> parseIdentifierExpr() {
|
||||
std::string name = lexer.getId();
|
||||
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
@ -188,11 +190,11 @@ private:
|
|||
|
||||
// This is a function call.
|
||||
lexer.consume(Token('('));
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
std::vector<std::unique_ptr<ExprAST>> args;
|
||||
if (lexer.getCurToken() != ')') {
|
||||
while (true) {
|
||||
if (auto Arg = ParseExpression())
|
||||
Args.push_back(std::move(Arg));
|
||||
if (auto arg = parseExpression())
|
||||
args.push_back(std::move(arg));
|
||||
else
|
||||
return nullptr;
|
||||
|
||||
|
@ -208,14 +210,14 @@ private:
|
|||
|
||||
// It can be a builtin call to print
|
||||
if (name == "print") {
|
||||
if (Args.size() != 1)
|
||||
if (args.size() != 1)
|
||||
return parseError<ExprAST>("<single arg>", "as argument to print()");
|
||||
|
||||
return std::make_unique<PrintExprAST>(std::move(loc), std::move(Args[0]));
|
||||
return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
|
||||
}
|
||||
|
||||
// Call to a user-defined function
|
||||
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(Args));
|
||||
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
|
||||
}
|
||||
|
||||
/// primary
|
||||
|
@ -223,20 +225,20 @@ private:
|
|||
/// ::= numberexpr
|
||||
/// ::= parenexpr
|
||||
/// ::= tensorliteral
|
||||
std::unique_ptr<ExprAST> ParsePrimary() {
|
||||
std::unique_ptr<ExprAST> parsePrimary() {
|
||||
switch (lexer.getCurToken()) {
|
||||
default:
|
||||
llvm::errs() << "unknown token '" << lexer.getCurToken()
|
||||
<< "' when expecting an expression\n";
|
||||
return nullptr;
|
||||
case tok_identifier:
|
||||
return ParseIdentifierExpr();
|
||||
return parseIdentifierExpr();
|
||||
case tok_number:
|
||||
return ParseNumberExpr();
|
||||
return parseNumberExpr();
|
||||
case '(':
|
||||
return ParseParenExpr();
|
||||
return parseParenExpr();
|
||||
case '[':
|
||||
return ParseTensorLiteralExpr();
|
||||
return parseTensorLiteralExpr();
|
||||
case ';':
|
||||
return nullptr;
|
||||
case '}':
|
||||
|
@ -248,54 +250,54 @@ private:
|
|||
/// argument indicates the precedence of the current binary operator.
|
||||
///
|
||||
/// binoprhs ::= ('+' primary)*
|
||||
std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
|
||||
std::unique_ptr<ExprAST> LHS) {
|
||||
std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
|
||||
std::unique_ptr<ExprAST> lhs) {
|
||||
// If this is a binop, find its precedence.
|
||||
while (true) {
|
||||
int TokPrec = GetTokPrecedence();
|
||||
int tokPrec = getTokPrecedence();
|
||||
|
||||
// If this is a binop that binds at least as tightly as the current binop,
|
||||
// consume it, otherwise we are done.
|
||||
if (TokPrec < ExprPrec)
|
||||
return LHS;
|
||||
if (tokPrec < exprPrec)
|
||||
return lhs;
|
||||
|
||||
// Okay, we know this is a binop.
|
||||
int BinOp = lexer.getCurToken();
|
||||
lexer.consume(Token(BinOp));
|
||||
int binOp = lexer.getCurToken();
|
||||
lexer.consume(Token(binOp));
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
||||
// Parse the primary expression after the binary operator.
|
||||
auto RHS = ParsePrimary();
|
||||
if (!RHS)
|
||||
auto rhs = parsePrimary();
|
||||
if (!rhs)
|
||||
return parseError<ExprAST>("expression", "to complete binary operator");
|
||||
|
||||
// If BinOp binds less tightly with RHS than the operator after RHS, let
|
||||
// the pending operator take RHS as its LHS.
|
||||
int NextPrec = GetTokPrecedence();
|
||||
if (TokPrec < NextPrec) {
|
||||
RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
|
||||
if (!RHS)
|
||||
// If BinOp binds less tightly with rhs than the operator after rhs, let
|
||||
// the pending operator take rhs as its lhs.
|
||||
int nextPrec = getTokPrecedence();
|
||||
if (tokPrec < nextPrec) {
|
||||
rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
|
||||
if (!rhs)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Merge LHS/RHS.
|
||||
LHS = std::make_unique<BinaryExprAST>(std::move(loc), BinOp,
|
||||
std::move(LHS), std::move(RHS));
|
||||
// Merge lhs/RHS.
|
||||
lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
|
||||
std::move(lhs), std::move(rhs));
|
||||
}
|
||||
}
|
||||
|
||||
/// expression::= primary binoprhs
|
||||
std::unique_ptr<ExprAST> ParseExpression() {
|
||||
auto LHS = ParsePrimary();
|
||||
if (!LHS)
|
||||
/// expression::= primary binop rhs
|
||||
std::unique_ptr<ExprAST> parseExpression() {
|
||||
auto lhs = parsePrimary();
|
||||
if (!lhs)
|
||||
return nullptr;
|
||||
|
||||
return ParseBinOpRHS(0, std::move(LHS));
|
||||
return parseBinOpRHS(0, std::move(lhs));
|
||||
}
|
||||
|
||||
/// type ::= < shape_list >
|
||||
/// shape_list ::= num | num , shape_list
|
||||
std::unique_ptr<VarType> ParseType() {
|
||||
std::unique_ptr<VarType> parseType() {
|
||||
if (lexer.getCurToken() != '<')
|
||||
return parseError<VarType>("<", "to begin type");
|
||||
lexer.getNextToken(); // eat <
|
||||
|
@ -319,7 +321,7 @@ private:
|
|||
/// and identifier and an optional type (shape specification) before the
|
||||
/// initializer.
|
||||
/// decl ::= var identifier [ type ] = expr
|
||||
std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
|
||||
std::unique_ptr<VarDeclExprAST> parseDeclaration() {
|
||||
if (lexer.getCurToken() != tok_var)
|
||||
return parseError<VarDeclExprAST>("var", "to begin declaration");
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
@ -333,7 +335,7 @@ private:
|
|||
|
||||
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
|
||||
if (lexer.getCurToken() == '<') {
|
||||
type = ParseType();
|
||||
type = parseType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -341,7 +343,7 @@ private:
|
|||
if (!type)
|
||||
type = std::make_unique<VarType>();
|
||||
lexer.consume(Token('='));
|
||||
auto expr = ParseExpression();
|
||||
auto expr = parseExpression();
|
||||
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
|
||||
std::move(*type), std::move(expr));
|
||||
}
|
||||
|
@ -352,7 +354,7 @@ private:
|
|||
/// block ::= { expression_list }
|
||||
/// expression_list ::= block_expr ; expression_list
|
||||
/// block_expr ::= decl | "return" | expr
|
||||
std::unique_ptr<ExprASTList> ParseBlock() {
|
||||
std::unique_ptr<ExprASTList> parseBlock() {
|
||||
if (lexer.getCurToken() != '{')
|
||||
return parseError<ExprASTList>("{", "to begin block");
|
||||
lexer.consume(Token('{'));
|
||||
|
@ -366,19 +368,19 @@ private:
|
|||
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
|
||||
if (lexer.getCurToken() == tok_var) {
|
||||
// Variable declaration
|
||||
auto varDecl = ParseDeclaration();
|
||||
auto varDecl = parseDeclaration();
|
||||
if (!varDecl)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(varDecl));
|
||||
} else if (lexer.getCurToken() == tok_return) {
|
||||
// Return statement
|
||||
auto ret = ParseReturn();
|
||||
auto ret = parseReturn();
|
||||
if (!ret)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(ret));
|
||||
} else {
|
||||
// General expression
|
||||
auto expr = ParseExpression();
|
||||
auto expr = parseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(expr));
|
||||
|
@ -401,13 +403,13 @@ private:
|
|||
|
||||
/// prototype ::= def id '(' decl_list ')'
|
||||
/// decl_list ::= identifier | identifier, decl_list
|
||||
std::unique_ptr<PrototypeAST> ParsePrototype() {
|
||||
std::unique_ptr<PrototypeAST> parsePrototype() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_def);
|
||||
if (lexer.getCurToken() != tok_identifier)
|
||||
return parseError<PrototypeAST>("function name", "in prototype");
|
||||
|
||||
std::string FnName = lexer.getId();
|
||||
std::string fnName = lexer.getId();
|
||||
lexer.consume(tok_identifier);
|
||||
|
||||
if (lexer.getCurToken() != '(')
|
||||
|
@ -435,7 +437,7 @@ private:
|
|||
|
||||
// success.
|
||||
lexer.consume(Token(')'));
|
||||
return std::make_unique<PrototypeAST>(std::move(loc), FnName,
|
||||
return std::make_unique<PrototypeAST>(std::move(loc), fnName,
|
||||
std::move(args));
|
||||
}
|
||||
|
||||
|
@ -443,18 +445,18 @@ private:
|
|||
/// `def` keyword, followed by a block containing a list of expressions.
|
||||
///
|
||||
/// definition ::= prototype block
|
||||
std::unique_ptr<FunctionAST> ParseDefinition() {
|
||||
auto Proto = ParsePrototype();
|
||||
if (!Proto)
|
||||
std::unique_ptr<FunctionAST> parseDefinition() {
|
||||
auto proto = parsePrototype();
|
||||
if (!proto)
|
||||
return nullptr;
|
||||
|
||||
if (auto block = ParseBlock())
|
||||
return std::make_unique<FunctionAST>(std::move(Proto), std::move(block));
|
||||
if (auto block = parseBlock())
|
||||
return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Get the precedence of the pending binary operator token.
|
||||
int GetTokPrecedence() {
|
||||
int getTokPrecedence() {
|
||||
if (!isascii(lexer.getCurToken()))
|
||||
return -1;
|
||||
|
||||
|
|
|
@ -143,7 +143,7 @@ private:
|
|||
// In MLIR the entry block of the function is special: it must have the same
|
||||
// argument list as the function itself.
|
||||
auto &entryBlock = *function.addEntryBlock();
|
||||
auto &protoArgs = funcAST.getProto()->getArgs();
|
||||
auto protoArgs = funcAST.getProto()->getArgs();
|
||||
|
||||
// Declare all the function arguments in the symbol table.
|
||||
for (const auto &name_value :
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "toy/AST.h"
|
||||
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
|
@ -40,22 +41,22 @@ struct Indent {
|
|||
/// the way. The only data member is the current indentation level.
|
||||
class ASTDumper {
|
||||
public:
|
||||
void dump(ModuleAST *Node);
|
||||
void dump(ModuleAST *node);
|
||||
|
||||
private:
|
||||
void dump(VarType &type);
|
||||
void dump(const VarType &type);
|
||||
void dump(VarDeclExprAST *varDecl);
|
||||
void dump(ExprAST *expr);
|
||||
void dump(ExprASTList *exprList);
|
||||
void dump(NumberExprAST *num);
|
||||
void dump(LiteralExprAST *Node);
|
||||
void dump(VariableExprAST *Node);
|
||||
void dump(ReturnExprAST *Node);
|
||||
void dump(BinaryExprAST *Node);
|
||||
void dump(CallExprAST *Node);
|
||||
void dump(PrintExprAST *Node);
|
||||
void dump(PrototypeAST *Node);
|
||||
void dump(FunctionAST *Node);
|
||||
void dump(LiteralExprAST *node);
|
||||
void dump(VariableExprAST *node);
|
||||
void dump(ReturnExprAST *node);
|
||||
void dump(BinaryExprAST *node);
|
||||
void dump(CallExprAST *node);
|
||||
void dump(PrintExprAST *node);
|
||||
void dump(PrototypeAST *node);
|
||||
void dump(FunctionAST *node);
|
||||
|
||||
// Actually print spaces matching the current indentation level
|
||||
void indent() {
|
||||
|
@ -68,8 +69,8 @@ private:
|
|||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *Node) {
|
||||
const auto &loc = Node->loc();
|
||||
template <typename T> static std::string loc(T *node) {
|
||||
const auto &loc = node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
.str();
|
||||
|
@ -129,56 +130,46 @@ void ASTDumper::dump(NumberExprAST *num) {
|
|||
/// [ [ 1, 2 ], [ 3, 4 ] ]
|
||||
/// We print out such array with the dimensions spelled out at every level:
|
||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||
void printLitHelper(ExprAST *lit_or_num) {
|
||||
void printLitHelper(ExprAST *litOrNum) {
|
||||
// Inside a literal expression we can have either a number or another literal
|
||||
if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
|
||||
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
||||
llvm::errs() << num->getValue();
|
||||
return;
|
||||
}
|
||||
auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
|
||||
auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
|
||||
|
||||
// Print the dimension for this literal first
|
||||
llvm::errs() << "<";
|
||||
{
|
||||
const char *sep = "";
|
||||
for (auto dim : literal->getDims()) {
|
||||
llvm::errs() << sep << dim;
|
||||
sep = ", ";
|
||||
}
|
||||
}
|
||||
mlir::interleaveComma(literal->getDims(), llvm::errs());
|
||||
llvm::errs() << ">";
|
||||
|
||||
// Now print the content, recursing on every element of the list
|
||||
llvm::errs() << "[ ";
|
||||
const char *sep = "";
|
||||
for (auto &elt : literal->getValues()) {
|
||||
llvm::errs() << sep;
|
||||
printLitHelper(elt.get());
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(literal->getValues(), llvm::errs(),
|
||||
[&](auto &elt) { printLitHelper(elt.get()); });
|
||||
llvm::errs() << "]";
|
||||
}
|
||||
|
||||
/// Print a literal, see the recursive helper above for the implementation.
|
||||
void ASTDumper::dump(LiteralExprAST *Node) {
|
||||
void ASTDumper::dump(LiteralExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Literal: ";
|
||||
printLitHelper(Node);
|
||||
llvm::errs() << " " << loc(Node) << "\n";
|
||||
printLitHelper(node);
|
||||
llvm::errs() << " " << loc(node) << "\n";
|
||||
}
|
||||
|
||||
/// Print a variable reference (just a name).
|
||||
void ASTDumper::dump(VariableExprAST *Node) {
|
||||
void ASTDumper::dump(VariableExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
|
||||
llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
|
||||
}
|
||||
|
||||
/// Return statement print the return and its (optional) argument.
|
||||
void ASTDumper::dump(ReturnExprAST *Node) {
|
||||
void ASTDumper::dump(ReturnExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Return\n";
|
||||
if (Node->getExpr().hasValue())
|
||||
return dump(*Node->getExpr());
|
||||
if (node->getExpr().hasValue())
|
||||
return dump(*node->getExpr());
|
||||
{
|
||||
INDENT();
|
||||
llvm::errs() << "(void)\n";
|
||||
|
@ -186,73 +177,66 @@ void ASTDumper::dump(ReturnExprAST *Node) {
|
|||
}
|
||||
|
||||
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
|
||||
void ASTDumper::dump(BinaryExprAST *Node) {
|
||||
void ASTDumper::dump(BinaryExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
|
||||
dump(Node->getLHS());
|
||||
dump(Node->getRHS());
|
||||
llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
|
||||
dump(node->getLHS());
|
||||
dump(node->getRHS());
|
||||
}
|
||||
|
||||
/// Print a call expression, first the callee name and the list of args by
|
||||
/// recursing into each individual argument.
|
||||
void ASTDumper::dump(CallExprAST *Node) {
|
||||
void ASTDumper::dump(CallExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
|
||||
for (auto &arg : Node->getArgs())
|
||||
llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
|
||||
for (auto &arg : node->getArgs())
|
||||
dump(arg.get());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a builtin print call, first the builtin name and then the argument.
|
||||
void ASTDumper::dump(PrintExprAST *Node) {
|
||||
void ASTDumper::dump(PrintExprAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Print [ " << loc(Node) << "\n";
|
||||
dump(Node->getArg());
|
||||
llvm::errs() << "Print [ " << loc(node) << "\n";
|
||||
dump(node->getArg());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print type: only the shape is printed in between '<' and '>'
|
||||
void ASTDumper::dump(VarType &type) {
|
||||
void ASTDumper::dump(const VarType &type) {
|
||||
llvm::errs() << "<";
|
||||
const char *sep = "";
|
||||
for (auto shape : type.shape) {
|
||||
llvm::errs() << sep << shape;
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(type.shape, llvm::errs());
|
||||
llvm::errs() << ">";
|
||||
}
|
||||
|
||||
/// Print a function prototype, first the function name, and then the list of
|
||||
/// parameters names.
|
||||
void ASTDumper::dump(PrototypeAST *Node) {
|
||||
void ASTDumper::dump(PrototypeAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
|
||||
llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
|
||||
indent();
|
||||
llvm::errs() << "Params: [";
|
||||
const char *sep = "";
|
||||
for (auto &arg : Node->getArgs()) {
|
||||
llvm::errs() << sep << arg->getName();
|
||||
sep = ", ";
|
||||
}
|
||||
mlir::interleaveComma(node->getArgs(), llvm::errs(),
|
||||
[](auto &arg) { llvm::errs() << arg->getName(); });
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a function, first the prototype and then the body.
|
||||
void ASTDumper::dump(FunctionAST *Node) {
|
||||
void ASTDumper::dump(FunctionAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Function \n";
|
||||
dump(Node->getProto());
|
||||
dump(Node->getBody());
|
||||
dump(node->getProto());
|
||||
dump(node->getBody());
|
||||
}
|
||||
|
||||
/// Print a module, actually loop over the functions and print them in sequence.
|
||||
void ASTDumper::dump(ModuleAST *Node) {
|
||||
void ASTDumper::dump(ModuleAST *node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Module:\n";
|
||||
for (auto &F : *Node)
|
||||
dump(&F);
|
||||
for (auto &f : *node)
|
||||
dump(&f);
|
||||
}
|
||||
|
||||
namespace toy {
|
||||
|
|
|
@ -85,20 +85,20 @@ static cl::opt<enum Action> emitAction(
|
|||
clEnumValN(RunJIT, "jit",
|
||||
"JIT the code and run it by invoking the main function")));
|
||||
|
||||
static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
|
||||
static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
|
||||
|
||||
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
|
||||
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
llvm::MemoryBuffer::getFileOrSTDIN(filename);
|
||||
if (std::error_code EC = FileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
||||
if (std::error_code ec = fileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
auto buffer = FileOrErr.get()->getBuffer();
|
||||
auto buffer = fileOrErr.get()->getBuffer();
|
||||
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
|
||||
Parser parser(lexer);
|
||||
return parser.ParseModule();
|
||||
return parser.parseModule();
|
||||
}
|
||||
|
||||
int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
|
||||
|
@ -142,7 +142,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
|
|||
bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine;
|
||||
bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM;
|
||||
|
||||
if (EnableOpt || isLoweringToAffine) {
|
||||
if (enableOpt || isLoweringToAffine) {
|
||||
// Inline all functions into main and then delete them.
|
||||
pm.addPass(mlir::createInlinerPass());
|
||||
pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
|
||||
|
@ -164,7 +164,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
|
|||
optPM.addPass(mlir::createCSEPass());
|
||||
|
||||
// Add optimizations if enabled.
|
||||
if (EnableOpt) {
|
||||
if (enableOpt) {
|
||||
optPM.addPass(mlir::createLoopFusionPass());
|
||||
optPM.addPass(mlir::createMemRefDataFlowOptPass());
|
||||
}
|
||||
|
@ -208,7 +208,7 @@ int dumpLLVMIR(mlir::ModuleOp module) {
|
|||
|
||||
/// Optionally run an optimization pipeline over the llvm module.
|
||||
auto optPipeline = mlir::makeOptimizingTransformer(
|
||||
/*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
|
||||
/*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
|
||||
/*targetMachine=*/nullptr);
|
||||
if (auto err = optPipeline(llvmModule.get())) {
|
||||
llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
|
||||
|
@ -225,7 +225,7 @@ int runJit(mlir::ModuleOp module) {
|
|||
|
||||
// An optimization pipeline to use within the execution engine.
|
||||
auto optPipeline = mlir::makeOptimizingTransformer(
|
||||
/*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
|
||||
/*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
|
||||
/*targetMachine=*/nullptr);
|
||||
|
||||
// Create an MLIR execution engine. The execution engine eagerly JIT-compiles
|
||||
|
|
Loading…
Reference in New Issue