forked from OSchip/llvm-project
Initial version for chapter 1 of the Toy tutorial
-- PiperOrigin-RevId: 241549247
This commit is contained in:
parent 7c1fc9e795
commit 38b71d6b84
mlir
|
@ -43,3 +43,7 @@ add_subdirectory(lib)
|
|||
add_subdirectory(tools)
|
||||
add_subdirectory(unittests)
|
||||
add_subdirectory(test)
|
||||
|
||||
if( LLVM_INCLUDE_EXAMPLES )
|
||||
add_subdirectory(examples)
|
||||
endif()
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
add_subdirectory(toy)
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
add_custom_target(Toy)
|
||||
set_target_properties(Toy PROPERTIES FOLDER Examples)
|
||||
|
||||
macro(add_toy_chapter name)
|
||||
add_dependencies(Toy ${name})
|
||||
add_llvm_example(${name} ${ARGN})
|
||||
endmacro(add_toy_chapter name)
|
||||
|
||||
add_subdirectory(Ch1)
|
|
@ -0,0 +1,9 @@
|
|||
set(LLVM_LINK_COMPONENTS
|
||||
Support
|
||||
)
|
||||
|
||||
add_toy_chapter(toyc-ch1
|
||||
toyc.cpp
|
||||
parser/AST.cpp
|
||||
)
|
||||
include_directories(include/)
|
|
@ -0,0 +1,256 @@
|
|||
//===- AST.h - Node definition for the Toy AST ----------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements the AST for the Toy language. It is optimized for
|
||||
// simplicity, not efficiency. The AST forms a tree structure where each node
|
||||
// references its children using std::unique_ptr<>.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TUTORIAL_TOY_AST_H_
|
||||
#define MLIR_TUTORIAL_TOY_AST_H_
|
||||
|
||||
#include "toy/Lexer.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include <vector>
|
||||
|
||||
namespace toy {
|
||||
|
||||
/// A variable type, carrying an element type and a shape.
|
||||
struct VarType {
|
||||
enum { TY_FLOAT, TY_INT } elt_ty;
|
||||
std::vector<int> shape;
|
||||
};
|
||||
|
||||
/// Base class for all expression nodes.
|
||||
class ExprAST {
|
||||
public:
|
||||
enum ExprASTKind {
|
||||
Expr_VarDecl,
|
||||
Expr_Return,
|
||||
Expr_Num,
|
||||
Expr_Literal,
|
||||
Expr_Var,
|
||||
Expr_BinOp,
|
||||
Expr_Call,
|
||||
Expr_Print, // builtin
|
||||
Expr_If,
|
||||
Expr_For,
|
||||
};
|
||||
|
||||
ExprAST(ExprASTKind kind, Location location)
|
||||
: kind(kind), location(location) {}
|
||||
|
||||
virtual ~ExprAST() = default;
|
||||
|
||||
ExprASTKind getKind() const { return kind; }
|
||||
|
||||
const Location &loc() { return location; }
|
||||
|
||||
private:
|
||||
const ExprASTKind kind;
|
||||
Location location;
|
||||
};
|
||||
|
||||
/// A block-list of expressions.
|
||||
using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
|
||||
|
||||
/// Expression class for numeric literals like "1.0".
|
||||
class NumberExprAST : public ExprAST {
|
||||
double Val;
|
||||
|
||||
public:
|
||||
NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
|
||||
|
||||
double getValue() { return Val; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
|
||||
};
|
||||
|
||||
/// Expression class for a literal value (a potentially nested array of numbers).
|
||||
class LiteralExprAST : public ExprAST {
|
||||
std::vector<std::unique_ptr<ExprAST>> values;
|
||||
std::vector<int64_t> dims;
|
||||
|
||||
public:
|
||||
LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
|
||||
std::vector<int64_t> dims)
|
||||
: ExprAST(Expr_Literal, loc), values(std::move(values)),
|
||||
dims(std::move(dims)) {}
|
||||
|
||||
std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
|
||||
std::vector<int64_t> &getDims() { return dims; }
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
|
||||
};
|
||||
|
||||
/// Expression class for referencing a variable, like "a".
|
||||
class VariableExprAST : public ExprAST {
|
||||
std::string name;
|
||||
|
||||
public:
|
||||
VariableExprAST(Location loc, const std::string &name)
|
||||
: ExprAST(Expr_Var, loc), name(name) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
|
||||
};
|
||||
|
||||
/// Expression class for defining a variable.
|
||||
class VarDeclExprAST : public ExprAST {
|
||||
std::string name;
|
||||
VarType type;
|
||||
std::unique_ptr<ExprAST> initVal;
|
||||
|
||||
public:
|
||||
VarDeclExprAST(Location loc, const std::string &name, VarType type,
|
||||
std::unique_ptr<ExprAST> initVal)
|
||||
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
|
||||
initVal(std::move(initVal)) {}
|
||||
|
||||
llvm::StringRef getName() { return name; }
|
||||
ExprAST *getInitVal() { return initVal.get(); }
|
||||
VarType &getType() { return type; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
|
||||
};
|
||||
|
||||
/// Expression class for a return statement, with an optional returned expression.
|
||||
class ReturnExprAST : public ExprAST {
|
||||
llvm::Optional<std::unique_ptr<ExprAST>> expr;
|
||||
|
||||
public:
|
||||
ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr)
|
||||
: ExprAST(Expr_Return, loc), expr(std::move(expr)) {}
|
||||
|
||||
llvm::Optional<ExprAST *> getExpr() {
|
||||
if (expr.hasValue())
|
||||
return expr->get();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
|
||||
};
|
||||
|
||||
/// Expression class for a binary operator.
|
||||
class BinaryExprAST : public ExprAST {
|
||||
char Op;
|
||||
std::unique_ptr<ExprAST> LHS, RHS;
|
||||
|
||||
public:
|
||||
char getOp() { return Op; }
|
||||
ExprAST *getLHS() { return LHS.get(); }
|
||||
ExprAST *getRHS() { return RHS.get(); }
|
||||
|
||||
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
|
||||
std::unique_ptr<ExprAST> RHS)
|
||||
: ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
|
||||
RHS(std::move(RHS)) {}
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
|
||||
};
|
||||
|
||||
/// Expression class for function calls.
|
||||
class CallExprAST : public ExprAST {
|
||||
std::string Callee;
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
|
||||
public:
|
||||
CallExprAST(Location loc, const std::string &Callee,
|
||||
std::vector<std::unique_ptr<ExprAST>> Args)
|
||||
: ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
|
||||
|
||||
llvm::StringRef getCallee() { return Callee; }
|
||||
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
|
||||
};
|
||||
|
||||
/// Expression class for builtin print calls.
|
||||
class PrintExprAST : public ExprAST {
|
||||
std::unique_ptr<ExprAST> Arg;
|
||||
|
||||
public:
|
||||
PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
|
||||
: ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
|
||||
|
||||
ExprAST *getArg() { return Arg.get(); }
|
||||
|
||||
/// LLVM style RTTI
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
|
||||
};
|
||||
|
||||
/// This class represents the "prototype" for a function, which captures its
|
||||
/// name, and its argument names (thus implicitly the number of arguments the
|
||||
/// function takes).
|
||||
class PrototypeAST {
|
||||
Location location;
|
||||
std::string name;
|
||||
std::vector<std::unique_ptr<VariableExprAST>> args;
|
||||
|
||||
public:
|
||||
PrototypeAST(Location location, const std::string &name,
|
||||
std::vector<std::unique_ptr<VariableExprAST>> args)
|
||||
: location(location), name(name), args(std::move(args)) {}
|
||||
|
||||
const Location &loc() { return location; }
|
||||
const std::string &getName() const { return name; }
|
||||
const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
|
||||
return args;
|
||||
}
|
||||
};
|
||||
|
||||
/// This class represents a function definition itself.
|
||||
class FunctionAST {
|
||||
std::unique_ptr<PrototypeAST> Proto;
|
||||
std::unique_ptr<ExprASTList> Body;
|
||||
|
||||
public:
|
||||
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
|
||||
std::unique_ptr<ExprASTList> Body)
|
||||
: Proto(std::move(Proto)), Body(std::move(Body)) {}
|
||||
PrototypeAST *getProto() { return Proto.get(); }
|
||||
ExprASTList *getBody() { return Body.get(); }
|
||||
};
|
||||
|
||||
/// This class represents a list of functions to be processed together
|
||||
class ModuleAST {
|
||||
std::vector<FunctionAST> functions;
|
||||
|
||||
public:
|
||||
ModuleAST(std::vector<FunctionAST> functions)
|
||||
: functions(std::move(functions)) {}
|
||||
|
||||
auto begin() -> decltype(functions.begin()) { return functions.begin(); }
|
||||
auto end() -> decltype(functions.end()) { return functions.end(); }
|
||||
};
|
||||
|
||||
void dump(ModuleAST &);
|
||||
|
||||
} // namespace toy
|
||||
|
||||
#endif // MLIR_TUTORIAL_TOY_AST_H_
|
|
@ -0,0 +1,239 @@
|
|||
//===- Lexer.h - Lexer for the Toy language -------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a simple Lexer for the Toy language.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TUTORIAL_TOY_LEXER_H_
|
||||
#define MLIR_TUTORIAL_TOY_LEXER_H_
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace toy {
|
||||
|
||||
/// Structure defining a location in a file.
|
||||
struct Location {
|
||||
std::shared_ptr<std::string> file; ///< filename
|
||||
int line; ///< line number.
|
||||
int col; ///< column number.
|
||||
};
|
||||
|
||||
// List of tokens returned by the lexer.
|
||||
enum Token : int {
|
||||
tok_semicolon = ';',
|
||||
tok_parenthese_open = '(',
|
||||
tok_parenthese_close = ')',
|
||||
tok_bracket_open = '{',
|
||||
tok_bracket_close = '}',
|
||||
tok_sbracket_open = '[',
|
||||
tok_sbracket_close = ']',
|
||||
|
||||
tok_eof = -1,
|
||||
|
||||
// commands
|
||||
tok_return = -2,
|
||||
tok_var = -3,
|
||||
tok_def = -4,
|
||||
|
||||
// primary
|
||||
tok_identifier = -5,
|
||||
tok_number = -6,
|
||||
};
|
||||
|
||||
/// The Lexer is an abstract base class providing all the facilities that the
|
||||
/// Parser expects. It goes through the stream one token at a time and keeps
|
||||
/// track of the location in the file for debugging purposes.
|
||||
/// It relies on a subclass to provide a `readNextLine()` method. The subclass
|
||||
/// can proceed by reading the next line from the standard input or from a
|
||||
/// memory mapped file.
|
||||
class Lexer {
|
||||
public:
|
||||
/// Create a lexer for the given filename. The filename is kept only for
|
||||
/// debugging purposes (attaching a location to a Token).
|
||||
Lexer(std::string filename)
|
||||
: lastLocation(
|
||||
{std::make_shared<std::string>(std::move(filename)), 0, 0}) {}
|
||||
virtual ~Lexer() = default;
|
||||
|
||||
/// Look at the current token in the stream.
|
||||
Token getCurToken() { return curTok; }
|
||||
|
||||
/// Move to the next token in the stream and return it.
|
||||
Token getNextToken() { return curTok = getTok(); }
|
||||
|
||||
/// Move to the next token in the stream, asserting on the current token
|
||||
/// matching the expectation.
|
||||
void consume(Token tok) {
|
||||
assert(tok == curTok && "consume Token mismatch expectation");
|
||||
getNextToken();
|
||||
}
|
||||
|
||||
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
|
||||
llvm::StringRef getId() {
|
||||
assert(curTok == tok_identifier);
|
||||
return IdentifierStr;
|
||||
}
|
||||
|
||||
/// Return the current number (prereq: getCurToken() == tok_number)
|
||||
double getValue() {
|
||||
assert(curTok == tok_number);
|
||||
return NumVal;
|
||||
}
|
||||
|
||||
/// Return the location for the beginning of the current token.
|
||||
Location getLastLocation() { return lastLocation; }
|
||||
|
||||
// Return the current line in the file.
|
||||
int getLine() { return curLineNum; }
|
||||
|
||||
// Return the current column in the file.
|
||||
int getCol() { return curCol; }
|
||||
|
||||
private:
|
||||
/// Delegate to a derived class fetching the next line. Returns an empty
|
||||
/// string to signal end of file (EOF). Lines are expected to always finish
|
||||
/// with "\n"
|
||||
virtual llvm::StringRef readNextLine() = 0;
|
||||
|
||||
/// Return the next character from the stream. This manages the buffer for the
|
||||
/// current line and requests the next line buffer from the derived class as
|
||||
/// needed.
|
||||
int getNextChar() {
|
||||
// The current line buffer should not be empty unless it is the end of file.
|
||||
if (curLineBuffer.empty())
|
||||
return EOF;
|
||||
++curCol;
|
||||
auto nextchar = curLineBuffer.front();
|
||||
curLineBuffer = curLineBuffer.drop_front();
|
||||
if (curLineBuffer.empty())
|
||||
curLineBuffer = readNextLine();
|
||||
if (nextchar == '\n') {
|
||||
++curLineNum;
|
||||
curCol = 0;
|
||||
}
|
||||
return nextchar;
|
||||
}
|
||||
|
||||
/// Return the next token from standard input.
|
||||
Token getTok() {
|
||||
// Skip any whitespace.
|
||||
while (isspace(LastChar))
|
||||
LastChar = Token(getNextChar());
|
||||
|
||||
// Save the current location before reading the token characters.
|
||||
lastLocation.line = curLineNum;
|
||||
lastLocation.col = curCol;
|
||||
|
||||
if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
|
||||
IdentifierStr = (char)LastChar;
|
||||
while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
|
||||
IdentifierStr += (char)LastChar;
|
||||
|
||||
if (IdentifierStr == "return")
|
||||
return tok_return;
|
||||
if (IdentifierStr == "def")
|
||||
return tok_def;
|
||||
if (IdentifierStr == "var")
|
||||
return tok_var;
|
||||
return tok_identifier;
|
||||
}
|
||||
|
||||
if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
|
||||
std::string NumStr;
|
||||
do {
|
||||
NumStr += LastChar;
|
||||
LastChar = Token(getNextChar());
|
||||
} while (isdigit(LastChar) || LastChar == '.');
|
||||
|
||||
NumVal = strtod(NumStr.c_str(), nullptr);
|
||||
return tok_number;
|
||||
}
|
||||
|
||||
if (LastChar == '#') {
|
||||
// Comment until end of line.
|
||||
do
|
||||
LastChar = Token(getNextChar());
|
||||
while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
|
||||
|
||||
if (LastChar != EOF)
|
||||
return getTok();
|
||||
}
|
||||
|
||||
// Check for end of file. Don't eat the EOF.
|
||||
if (LastChar == EOF)
|
||||
return tok_eof;
|
||||
|
||||
// Otherwise, just return the character as its ascii value.
|
||||
Token ThisChar = Token(LastChar);
|
||||
LastChar = Token(getNextChar());
|
||||
return ThisChar;
|
||||
}
|
||||
|
||||
/// The last token read from the input.
|
||||
Token curTok = tok_eof;
|
||||
|
||||
/// Location for `curTok`.
|
||||
Location lastLocation;
|
||||
|
||||
/// If the current Token is an identifier, this string contains the value.
|
||||
std::string IdentifierStr;
|
||||
|
||||
/// If the current Token is a number, this contains the value.
|
||||
double NumVal = 0;
|
||||
|
||||
/// The last value returned by getNextChar(). We need to keep it around as we
|
||||
/// always need to read ahead one character to decide when to end a token and
|
||||
/// we can't put it back in the stream after reading from it.
|
||||
Token LastChar = Token(' ');
|
||||
|
||||
/// Keep track of the current line number in the input stream
|
||||
int curLineNum = 0;
|
||||
|
||||
/// Keep track of the current column number in the input stream
|
||||
int curCol = 0;
|
||||
|
||||
/// Buffer supplied by the derived class on calls to `readNextLine()`
|
||||
llvm::StringRef curLineBuffer = "\n";
|
||||
};
|
||||
|
||||
/// A lexer implementation operating on a buffer in memory.
|
||||
class LexerBuffer final : public Lexer {
|
||||
public:
|
||||
LexerBuffer(const char *begin, const char *end, std::string filename)
|
||||
: Lexer(std::move(filename)), current(begin), end(end) {}
|
||||
|
||||
private:
|
||||
/// Provide one line at a time to the Lexer, returning an empty string when
|
||||
/// reaching the end of the buffer.
|
||||
llvm::StringRef readNextLine() override {
|
||||
auto *begin = current;
|
||||
while (current <= end && *current && *current != '\n')
|
||||
++current;
|
||||
if (current <= end && *current)
|
||||
++current;
|
||||
llvm::StringRef result{begin, static_cast<size_t>(current - begin)};
|
||||
return result;
|
||||
}
|
||||
const char *current, *end;
|
||||
};
|
||||
} // namespace toy
|
||||
|
||||
#endif // MLIR_TUTORIAL_TOY_LEXER_H_
|
|
@ -0,0 +1,494 @@
|
|||
//===- Parser.h - Toy Language Parser -------------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements the parser for the Toy language. It processes the Token
|
||||
// provided by the Lexer and returns an AST.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TUTORIAL_TOY_PARSER_H
|
||||
#define MLIR_TUTORIAL_TOY_PARSER_H
|
||||
|
||||
#include "toy/AST.h"
|
||||
#include "toy/Lexer.h"
|
||||
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace toy {
|
||||
|
||||
/// This is a simple recursive parser for the Toy language. It produces a
/// well-formed AST from a stream of tokens supplied by the Lexer. No semantic
/// checks or symbol resolution are performed. For example, variables are
/// referenced by string and the code could reference an undeclared variable;
/// parsing will still succeed.
|
||||
class Parser {
|
||||
public:
|
||||
/// Create a Parser for the supplied lexer.
|
||||
Parser(Lexer &lexer) : lexer(lexer) {}
|
||||
|
||||
/// Parse a full Module. A module is a list of function definitions.
|
||||
std::unique_ptr<ModuleAST> ParseModule() {
|
||||
lexer.getNextToken(); // prime the lexer
|
||||
|
||||
// Parse functions one at a time and accumulate in this vector.
|
||||
std::vector<FunctionAST> functions;
|
||||
while (auto F = ParseDefinition()) {
|
||||
functions.push_back(std::move(*F));
|
||||
if (lexer.getCurToken() == tok_eof)
|
||||
break;
|
||||
}
|
||||
// If we didn't reach EOF, there was an error during parsing
|
||||
if (lexer.getCurToken() != tok_eof)
|
||||
return parseError<ModuleAST>("nothing", "at end of module");
|
||||
|
||||
return llvm::make_unique<ModuleAST>(std::move(functions));
|
||||
}
|
||||
|
||||
private:
|
||||
Lexer &lexer;
|
||||
|
||||
/// Parse a return statement.
|
||||
/// return ::= return ; | return expr ;
|
||||
std::unique_ptr<ReturnExprAST> ParseReturn() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_return);
|
||||
|
||||
// return takes an optional argument
|
||||
llvm::Optional<std::unique_ptr<ExprAST>> expr;
|
||||
if (lexer.getCurToken() != ';') {
|
||||
expr = ParseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
}
|
||||
return llvm::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
|
||||
}
|
||||
|
||||
/// Parse a literal number.
|
||||
/// numberexpr ::= number
|
||||
std::unique_ptr<ExprAST> ParseNumberExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
auto Result =
|
||||
llvm::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
|
||||
lexer.consume(tok_number);
|
||||
return std::move(Result);
|
||||
}
|
||||
|
||||
/// Parse a literal array expression.
|
||||
/// tensorLiteral ::= [ literalList ] | number
|
||||
/// literalList ::= tensorLiteral | tensorLiteral, literalList
|
||||
std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(Token('['));
|
||||
|
||||
// Hold the list of values at this nesting level.
|
||||
std::vector<std::unique_ptr<ExprAST>> values;
|
||||
// Hold the dimensions for all the nesting inside this level.
|
||||
std::vector<int64_t> dims;
|
||||
do {
|
||||
// We can have either another nested array or a number literal.
|
||||
if (lexer.getCurToken() == '[') {
|
||||
values.push_back(ParseTensorLiteralExpr());
|
||||
if (!values.back())
|
||||
return nullptr; // parse error in the nested array.
|
||||
} else {
|
||||
if (lexer.getCurToken() != tok_number)
|
||||
return parseError<ExprAST>("<num> or [", "in literal expression");
|
||||
values.push_back(ParseNumberExpr());
|
||||
}
|
||||
|
||||
// End of this list on ']'
|
||||
if (lexer.getCurToken() == ']')
|
||||
break;
|
||||
|
||||
// Elements are separated by a comma.
|
||||
if (lexer.getCurToken() != ',')
|
||||
return parseError<ExprAST>("] or ,", "in literal expression");
|
||||
|
||||
lexer.getNextToken(); // eat ,
|
||||
} while (true);
|
||||
if (values.empty())
|
||||
return parseError<ExprAST>("<something>", "to fill literal expression");
|
||||
lexer.getNextToken(); // eat ]
|
||||
/// Fill in the dimensions now. First the current nesting level:
|
||||
dims.push_back(values.size());
|
||||
/// If there is any nested array, process all of them and ensure that
|
||||
/// dimensions are uniform.
|
||||
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
|
||||
return llvm::isa<LiteralExprAST>(expr.get());
|
||||
})) {
|
||||
auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get());
|
||||
if (!firstLiteral)
|
||||
return parseError<ExprAST>("uniform well-nested dimensions",
|
||||
"inside literal expession");
|
||||
|
||||
// Append the nested dimensions to the current level
|
||||
auto &firstDims = firstLiteral->getDims();
|
||||
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
|
||||
|
||||
// Sanity check that shape is uniform across all elements of the list.
|
||||
for (auto &expr : values) {
|
||||
auto *exprLiteral = llvm::dyn_cast<LiteralExprAST>(expr.get());
|
||||
if (!exprLiteral)
|
||||
return parseError<ExprAST>("uniform well-nested dimensions",
|
||||
"inside literal expession");
|
||||
if (exprLiteral->getDims() != firstDims)
|
||||
return parseError<ExprAST>("uniform well-nested dimensions",
|
||||
"inside literal expession");
|
||||
}
|
||||
}
|
||||
return llvm::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
|
||||
std::move(dims));
|
||||
}
|
||||
|
||||
/// parenexpr ::= '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseParenExpr() {
|
||||
lexer.getNextToken(); // eat (.
|
||||
auto V = ParseExpression();
|
||||
if (!V)
|
||||
return nullptr;
|
||||
|
||||
if (lexer.getCurToken() != ')')
|
||||
return parseError<ExprAST>(")", "to close expression with parentheses");
|
||||
lexer.consume(Token(')'));
|
||||
return V;
|
||||
}
|
||||
|
||||
/// identifierexpr
|
||||
/// ::= identifier
|
||||
/// ::= identifier '(' expression ')'
|
||||
std::unique_ptr<ExprAST> ParseIdentifierExpr() {
|
||||
std::string name = lexer.getId();
|
||||
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.getNextToken(); // eat identifier.
|
||||
|
||||
if (lexer.getCurToken() != '(') // Simple variable ref.
|
||||
return llvm::make_unique<VariableExprAST>(std::move(loc), name);
|
||||
|
||||
// This is a function call.
|
||||
lexer.consume(Token('('));
|
||||
std::vector<std::unique_ptr<ExprAST>> Args;
|
||||
if (lexer.getCurToken() != ')') {
|
||||
while (true) {
|
||||
if (auto Arg = ParseExpression())
|
||||
Args.push_back(std::move(Arg));
|
||||
else
|
||||
return nullptr;
|
||||
|
||||
if (lexer.getCurToken() == ')')
|
||||
break;
|
||||
|
||||
if (lexer.getCurToken() != ',')
|
||||
return parseError<ExprAST>(", or )", "in argument list");
|
||||
lexer.getNextToken();
|
||||
}
|
||||
}
|
||||
lexer.consume(Token(')'));
|
||||
|
||||
// It can be a builtin call to print
|
||||
if (name == "print") {
|
||||
if (Args.size() != 1)
|
||||
return parseError<ExprAST>("<single arg>", "as argument to print()");
|
||||
|
||||
return llvm::make_unique<PrintExprAST>(std::move(loc),
|
||||
std::move(Args[0]));
|
||||
}
|
||||
|
||||
// Call to a user-defined function
|
||||
return llvm::make_unique<CallExprAST>(std::move(loc), name,
|
||||
std::move(Args));
|
||||
}
|
||||
|
||||
/// primary
|
||||
/// ::= identifierexpr
|
||||
/// ::= numberexpr
|
||||
/// ::= parenexpr
|
||||
/// ::= tensorliteral
|
||||
std::unique_ptr<ExprAST> ParsePrimary() {
|
||||
switch (lexer.getCurToken()) {
|
||||
default:
|
||||
llvm::errs() << "unknown token '" << lexer.getCurToken()
|
||||
<< "' when expecting an expression\n";
|
||||
return nullptr;
|
||||
case tok_identifier:
|
||||
return ParseIdentifierExpr();
|
||||
case tok_number:
|
||||
return ParseNumberExpr();
|
||||
case '(':
|
||||
return ParseParenExpr();
|
||||
case '[':
|
||||
return ParseTensorLiteralExpr();
|
||||
case ';':
|
||||
return nullptr;
|
||||
case '}':
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
/// Recursively parse the right hand side of a binary expression; the ExprPrec
|
||||
/// argument indicates the precedence of the current binary operator.
|
||||
///
|
||||
/// binoprhs ::= ('+' primary)*
|
||||
std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
|
||||
std::unique_ptr<ExprAST> LHS) {
|
||||
// If this is a binop, find its precedence.
|
||||
while (true) {
|
||||
int TokPrec = GetTokPrecedence();
|
||||
|
||||
// If this is a binop that binds at least as tightly as the current binop,
|
||||
// consume it, otherwise we are done.
|
||||
if (TokPrec < ExprPrec)
|
||||
return LHS;
|
||||
|
||||
// Okay, we know this is a binop.
|
||||
int BinOp = lexer.getCurToken();
|
||||
lexer.consume(Token(BinOp));
|
||||
auto loc = lexer.getLastLocation();
|
||||
|
||||
// Parse the primary expression after the binary operator.
|
||||
auto RHS = ParsePrimary();
|
||||
if (!RHS)
|
||||
return parseError<ExprAST>("expression", "to complete binary operator");
|
||||
|
||||
// If BinOp binds less tightly with RHS than the operator after RHS, let
|
||||
// the pending operator take RHS as its LHS.
|
||||
int NextPrec = GetTokPrecedence();
|
||||
if (TokPrec < NextPrec) {
|
||||
RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
|
||||
if (!RHS)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Merge LHS/RHS.
|
||||
LHS = llvm::make_unique<BinaryExprAST>(std::move(loc), BinOp,
|
||||
std::move(LHS), std::move(RHS));
|
||||
}
|
||||
}
|
||||
|
||||
/// expression::= primary binoprhs
|
||||
std::unique_ptr<ExprAST> ParseExpression() {
|
||||
auto LHS = ParsePrimary();
|
||||
if (!LHS)
|
||||
return nullptr;
|
||||
|
||||
return ParseBinOpRHS(0, std::move(LHS));
|
||||
}
|
||||
|
||||
/// type ::= < shape_list >
|
||||
/// shape_list ::= num | num , shape_list
|
||||
std::unique_ptr<VarType> ParseType() {
|
||||
if (lexer.getCurToken() != '<')
|
||||
return parseError<VarType>("<", "to begin type");
|
||||
lexer.getNextToken(); // eat <
|
||||
|
||||
auto type = llvm::make_unique<VarType>();
|
||||
|
||||
while (lexer.getCurToken() == tok_number) {
|
||||
type->shape.push_back(lexer.getValue());
|
||||
lexer.getNextToken();
|
||||
if (lexer.getCurToken() == ',')
|
||||
lexer.getNextToken();
|
||||
}
|
||||
|
||||
if (lexer.getCurToken() != '>')
|
||||
return parseError<VarType>(">", "to end type");
|
||||
lexer.getNextToken(); // eat >
|
||||
return type;
|
||||
}
|
||||
|
||||
/// Parse a variable declaration: it starts with a `var` keyword followed by
/// an identifier and an optional type (shape specification) before the
|
||||
/// initializer.
|
||||
/// decl ::= var identifier [ type ] = expr
|
||||
std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
|
||||
if (lexer.getCurToken() != tok_var)
|
||||
return parseError<VarDeclExprAST>("var", "to begin declaration");
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.getNextToken(); // eat var
|
||||
|
||||
if (lexer.getCurToken() != tok_identifier)
|
||||
return parseError<VarDeclExprAST>("identified",
|
||||
"after 'var' declaration");
|
||||
std::string id = lexer.getId();
|
||||
lexer.getNextToken(); // eat id
|
||||
|
||||
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
|
||||
if (lexer.getCurToken() == '<') {
|
||||
type = ParseType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!type)
|
||||
type = llvm::make_unique<VarType>();
|
||||
lexer.consume(Token('='));
|
||||
auto expr = ParseExpression();
|
||||
return llvm::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
|
||||
std::move(*type), std::move(expr));
|
||||
}
|
||||
|
||||
/// Parse a block: a list of expressions separated by semicolons and wrapped in
|
||||
/// curly braces.
|
||||
///
|
||||
/// block ::= { expression_list }
|
||||
/// expression_list ::= block_expr ; expression_list
|
||||
/// block_expr ::= decl | "return" | expr
|
||||
std::unique_ptr<ExprASTList> ParseBlock() {
|
||||
if (lexer.getCurToken() != '{')
|
||||
return parseError<ExprASTList>("{", "to begin block");
|
||||
lexer.consume(Token('{'));
|
||||
|
||||
auto exprList = llvm::make_unique<ExprASTList>();
|
||||
|
||||
// Ignore empty expressions: swallow sequences of semicolons.
|
||||
while (lexer.getCurToken() == ';')
|
||||
lexer.consume(Token(';'));
|
||||
|
||||
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
|
||||
if (lexer.getCurToken() == tok_var) {
|
||||
// Variable declaration
|
||||
auto varDecl = ParseDeclaration();
|
||||
if (!varDecl)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(varDecl));
|
||||
} else if (lexer.getCurToken() == tok_return) {
|
||||
// Return statement
|
||||
auto ret = ParseReturn();
|
||||
if (!ret)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(ret));
|
||||
} else {
|
||||
// General expression
|
||||
auto expr = ParseExpression();
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
exprList->push_back(std::move(expr));
|
||||
}
|
||||
// Ensure that elements are separated by a semicolon.
|
||||
if (lexer.getCurToken() != ';')
|
||||
return parseError<ExprASTList>(";", "after expression");
|
||||
|
||||
// Ignore empty expressions: swallow sequences of semicolons.
|
||||
while (lexer.getCurToken() == ';')
|
||||
lexer.consume(Token(';'));
|
||||
}
|
||||
|
||||
if (lexer.getCurToken() != '}')
|
||||
return parseError<ExprASTList>("}", "to close block");
|
||||
|
||||
lexer.consume(Token('}'));
|
||||
return exprList;
|
||||
}
|
||||
|
||||
/// prototype ::= def id '(' decl_list ')'
|
||||
/// decl_list ::= identifier | identifier, decl_list
|
||||
std::unique_ptr<PrototypeAST> ParsePrototype() {
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_def);
|
||||
if (lexer.getCurToken() != tok_identifier)
|
||||
return parseError<PrototypeAST>("function name", "in prototype");
|
||||
|
||||
std::string FnName = lexer.getId();
|
||||
lexer.consume(tok_identifier);
|
||||
|
||||
if (lexer.getCurToken() != '(')
|
||||
return parseError<PrototypeAST>("(", "in prototype");
|
||||
lexer.consume(Token('('));
|
||||
|
||||
std::vector<std::unique_ptr<VariableExprAST>> args;
|
||||
if (lexer.getCurToken() != ')') {
|
||||
do {
|
||||
std::string name = lexer.getId();
|
||||
auto loc = lexer.getLastLocation();
|
||||
lexer.consume(tok_identifier);
|
||||
auto decl = llvm::make_unique<VariableExprAST>(std::move(loc), name);
|
||||
args.push_back(std::move(decl));
|
||||
if (lexer.getCurToken() != ',')
|
||||
break;
|
||||
lexer.consume(Token(','));
|
||||
if (lexer.getCurToken() != tok_identifier)
|
||||
return parseError<PrototypeAST>(
|
||||
"identifier", "after ',' in function parameter list");
|
||||
} while (true);
|
||||
}
|
||||
if (lexer.getCurToken() != ')')
|
||||
return parseError<PrototypeAST>("}", "to end function prototype");
|
||||
|
||||
// success.
|
||||
lexer.consume(Token(')'));
|
||||
return llvm::make_unique<PrototypeAST>(std::move(loc), FnName,
|
||||
std::move(args));
|
||||
}
|
||||
|
||||
/// Parse a function definition; we expect a prototype initiated with the
|
||||
/// `def` keyword, followed by a block containing a list of expressions.
|
||||
///
|
||||
/// definition ::= prototype block
|
||||
std::unique_ptr<FunctionAST> ParseDefinition() {
|
||||
auto Proto = ParsePrototype();
|
||||
if (!Proto)
|
||||
return nullptr;
|
||||
|
||||
if (auto block = ParseBlock())
|
||||
return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(block));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Get the precedence of the pending binary operator token.
|
||||
int GetTokPrecedence() {
|
||||
if (!isascii(lexer.getCurToken()))
|
||||
return -1;
|
||||
|
||||
// 1 is lowest precedence.
|
||||
switch (static_cast<char>(lexer.getCurToken())) {
|
||||
case '-':
|
||||
return 20;
|
||||
case '+':
|
||||
return 20;
|
||||
case '*':
|
||||
return 40;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to signal errors while parsing; it takes an argument
|
||||
/// indicating the expected token and another argument giving more context.
|
||||
/// Location is retrieved from the lexer to enrich the error message.
|
||||
template <typename R, typename T, typename U = const char *>
|
||||
std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
|
||||
auto curToken = lexer.getCurToken();
|
||||
llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
|
||||
<< lexer.getLastLocation().col << "): expected '" << expected
|
||||
<< "' " << context << " but has Token " << curToken;
|
||||
if (isprint(curToken))
|
||||
llvm::errs() << " '" << (char)curToken << "'";
|
||||
llvm::errs() << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace toy
|
||||
|
||||
#endif // MLIR_TUTORIAL_TOY_PARSER_H
|
|
@ -0,0 +1,263 @@
|
|||
//===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements the AST dump for the Toy language.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "toy/AST.h"
|
||||
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace toy;
|
||||
|
||||
namespace {
|
||||
|
||||
// RAII helper to manage increasing/decreasing the indentation as we traverse
|
||||
// the AST
|
||||
struct Indent {
|
||||
Indent(int &level) : level(level) { ++level; }
|
||||
~Indent() { --level; }
|
||||
int &level;
|
||||
};
|
||||
|
||||
/// Helper class that implements the AST tree traversal and prints the nodes
/// along the way. The only data member is the current indentation level.
|
||||
class ASTDumper {
|
||||
public:
|
||||
void dump(ModuleAST *Node);
|
||||
|
||||
private:
|
||||
void dump(VarType &type);
|
||||
void dump(VarDeclExprAST *varDecl);
|
||||
void dump(ExprAST *expr);
|
||||
void dump(ExprASTList *exprList);
|
||||
void dump(NumberExprAST *num);
|
||||
void dump(LiteralExprAST *Node);
|
||||
void dump(VariableExprAST *Node);
|
||||
void dump(ReturnExprAST *Node);
|
||||
void dump(BinaryExprAST *Node);
|
||||
void dump(CallExprAST *Node);
|
||||
void dump(PrintExprAST *Node);
|
||||
void dump(PrototypeAST *Node);
|
||||
void dump(FunctionAST *Node);
|
||||
|
||||
// Actually print spaces matching the current indentation level
|
||||
void indent() {
|
||||
for (int i = 0; i < curIndent; i++)
|
||||
llvm::errs() << " ";
|
||||
}
|
||||
int curIndent = 0;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *Node) {
|
||||
const auto &loc = Node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
.str();
|
||||
}
|
||||
|
||||
// Helper macro to bump the indentation level and print the leading spaces for
// the current indentation.
|
||||
#define INDENT() \
|
||||
Indent level_(curIndent); \
|
||||
indent();
|
||||
|
||||
/// Dispatch a generic expression to the appropriate subclass using RTTI
|
||||
void ASTDumper::dump(ExprAST *expr) {
|
||||
#define dispatch(CLASS) \
|
||||
if (CLASS *node = llvm::dyn_cast<CLASS>(expr)) \
|
||||
return dump(node);
|
||||
dispatch(VarDeclExprAST);
|
||||
dispatch(LiteralExprAST);
|
||||
dispatch(NumberExprAST);
|
||||
dispatch(VariableExprAST);
|
||||
dispatch(ReturnExprAST);
|
||||
dispatch(BinaryExprAST);
|
||||
dispatch(CallExprAST);
|
||||
dispatch(PrintExprAST);
|
||||
// No match, fallback to a generic message
|
||||
INDENT();
|
||||
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
|
||||
}
|
||||
|
||||
/// A variable declaration prints the variable name and the type, and then
/// recurses into the initializer value.
|
||||
void ASTDumper::dump(VarDeclExprAST *varDecl) {
|
||||
INDENT();
|
||||
llvm::errs() << "VarDecl " << varDecl->getName();
|
||||
dump(varDecl->getType());
|
||||
llvm::errs() << " " << loc(varDecl) << "\n";
|
||||
dump(varDecl->getInitVal());
|
||||
}
|
||||
|
||||
/// A "block", or a list of expression
|
||||
void ASTDumper::dump(ExprASTList *exprList) {
|
||||
INDENT();
|
||||
llvm::errs() << "Block {\n";
|
||||
for (auto &expr : *exprList)
|
||||
dump(expr.get());
|
||||
indent();
|
||||
llvm::errs() << "} // Block\n";
|
||||
}
|
||||
|
||||
/// A literal number, just print the value.
|
||||
void ASTDumper::dump(NumberExprAST *num) {
|
||||
INDENT();
|
||||
llvm::errs() << num->getValue() << " " << loc(num) << "\n";
|
||||
}
|
||||
|
||||
/// Helper to recursively print a literal. This handles nested arrays like:
|
||||
/// [ [ 1, 2 ], [ 3, 4 ] ]
|
||||
/// We print out such arrays with the dimensions spelled out at every level:
|
||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||
void printLitHelper(ExprAST *lit_or_num) {
|
||||
// Inside a literal expression we can have either a number or another literal
|
||||
if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
|
||||
llvm::errs() << num->getValue();
|
||||
return;
|
||||
}
|
||||
auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
|
||||
|
||||
// Print the dimension for this literal first
|
||||
llvm::errs() << "<";
|
||||
{
|
||||
const char *sep = "";
|
||||
for (auto dim : literal->getDims()) {
|
||||
llvm::errs() << sep << dim;
|
||||
sep = ", ";
|
||||
}
|
||||
}
|
||||
llvm::errs() << ">";
|
||||
|
||||
// Now print the content, recursing on every element of the list
|
||||
llvm::errs() << "[ ";
|
||||
const char *sep = "";
|
||||
for (auto &elt : literal->getValues()) {
|
||||
llvm::errs() << sep;
|
||||
printLitHelper(elt.get());
|
||||
sep = ", ";
|
||||
}
|
||||
llvm::errs() << "]";
|
||||
}
|
||||
|
||||
/// Print a literal, see the recursive helper above for the implementation.
|
||||
void ASTDumper::dump(LiteralExprAST *Node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Literal: ";
|
||||
printLitHelper(Node);
|
||||
llvm::errs() << " " << loc(Node) << "\n";
|
||||
}
|
||||
|
||||
/// Print a variable reference (just a name).
|
||||
void ASTDumper::dump(VariableExprAST *Node) {
|
||||
INDENT();
|
||||
llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
|
||||
}
|
||||
|
||||
/// Print a return statement and its (optional) returned expression.
|
||||
void ASTDumper::dump(ReturnExprAST *Node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Return\n";
|
||||
if (Node->getExpr().hasValue())
|
||||
return dump(*Node->getExpr());
|
||||
{
|
||||
INDENT();
|
||||
llvm::errs() << "(void)\n";
|
||||
}
|
||||
}
|
||||
|
||||
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
|
||||
void ASTDumper::dump(BinaryExprAST *Node) {
|
||||
INDENT();
|
||||
llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
|
||||
dump(Node->getLHS());
|
||||
dump(Node->getRHS());
|
||||
}
|
||||
|
||||
/// Print a call expression, first the callee name and the list of args by
|
||||
/// recursing into each individual argument.
|
||||
void ASTDumper::dump(CallExprAST *Node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
|
||||
for (auto &arg : Node->getArgs())
|
||||
dump(arg.get());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a builtin print call, first the builtin name and then the argument.
|
||||
void ASTDumper::dump(PrintExprAST *Node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Print [ " << loc(Node) << "\n";
|
||||
dump(Node->getArg());
|
||||
indent();
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print type: only the shape is printed in between '<' and '>'
|
||||
void ASTDumper::dump(VarType &type) {
|
||||
llvm::errs() << "<";
|
||||
const char *sep = "";
|
||||
for (auto shape : type.shape) {
|
||||
llvm::errs() << sep << shape;
|
||||
sep = ", ";
|
||||
}
|
||||
llvm::errs() << ">";
|
||||
}
|
||||
|
||||
/// Print a function prototype, first the function name, and then the list of
|
||||
/// parameter names.
|
||||
void ASTDumper::dump(PrototypeAST *Node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
|
||||
indent();
|
||||
llvm::errs() << "Params: [";
|
||||
const char *sep = "";
|
||||
for (auto &arg : Node->getArgs()) {
|
||||
llvm::errs() << sep << arg->getName();
|
||||
sep = ", ";
|
||||
}
|
||||
llvm::errs() << "]\n";
|
||||
}
|
||||
|
||||
/// Print a function, first the prototype and then the body.
|
||||
void ASTDumper::dump(FunctionAST *Node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Function \n";
|
||||
dump(Node->getProto());
|
||||
dump(Node->getBody());
|
||||
}
|
||||
|
||||
/// Print a module: loop over the functions and print them in sequence.
|
||||
void ASTDumper::dump(ModuleAST *Node) {
|
||||
INDENT();
|
||||
llvm::errs() << "Module:\n";
|
||||
for (auto &F : *Node)
|
||||
dump(&F);
|
||||
}
|
||||
|
||||
namespace toy {
|
||||
|
||||
// Public API
|
||||
void dump(ModuleAST &module) { ASTDumper().dump(&module); }
|
||||
|
||||
} // namespace toy
|
|
@ -0,0 +1,75 @@
|
|||
//===- toyc.cpp - The Toy Compiler ----------------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements the entry point for the Toy compiler.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "toy/Parser.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/ErrorOr.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace toy;
|
||||
namespace cl = llvm::cl;
|
||||
|
||||
static cl::opt<std::string> InputFilename(cl::Positional,
|
||||
cl::desc("<input toy file>"),
|
||||
cl::init("-"),
|
||||
cl::value_desc("filename"));
|
||||
namespace {
|
||||
enum Action { None, DumpAST };
|
||||
}
|
||||
|
||||
static cl::opt<enum Action>
|
||||
emitAction("emit", cl::desc("Select the kind of output desired"),
|
||||
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")));
|
||||
|
||||
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
|
||||
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
|
||||
llvm::MemoryBuffer::getFileOrSTDIN(filename);
|
||||
if (std::error_code EC = FileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
auto buffer = FileOrErr.get()->getBuffer();
|
||||
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
|
||||
Parser parser(lexer);
|
||||
return parser.ParseModule();
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
|
||||
|
||||
auto moduleAST = parseInputFile(InputFilename);
|
||||
if (!moduleAST)
|
||||
return 1;
|
||||
|
||||
switch (emitAction) {
|
||||
case Action::DumpAST:
|
||||
dump(*moduleAST);
|
||||
return 0;
|
||||
default:
|
||||
llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,149 @@
|
|||
# Chapter 1: Toy Tutorial Introduction
|
||||
|
||||
This tutorial runs through the implementation of a basic toy language on top of
|
||||
MLIR. The goal of this tutorial is to introduce the concepts of MLIR, and
|
||||
especially how *dialects* can help easily support language specific constructs
|
||||
and transformations, while still offering an easy path to lower to LLVM or other
|
||||
codegen infrastructure. This tutorial is based on the model of the
|
||||
[LLVM Kaleidoscope Tutorial](https://llvm.org/docs/tutorial/LangImpl01.html).
|
||||
|
||||
This tutorial is divided into the following chapters:
|
||||
|
||||
- [Chapter #1](Ch-1.md): Introduction to the Toy language, and the definition
|
||||
of its AST.
|
||||
- [Chapter #2](Ch-2.md): Traversing the AST to emit custom MLIR, introducing
|
||||
base MLIR concepts.
|
||||
- [Chapter #3](Ch-3.md): Defining and registering a dialect in MLIR, showing
|
||||
how we can start attaching semantics to our custom operations in MLIR.
|
||||
- [Chapter #4](Ch-4.md): High-level language-specific analysis and
|
||||
transformation, showcasing shape inference, generic function specialization,
|
||||
and basic optimizations.
|
||||
- [Chapter #5](Ch-5.md): Lowering to lower-level dialects. We'll convert our
|
||||
high level language specific semantics towards a generic linear-algebra
|
||||
oriented dialect for optimizations. Ultimately we will emit LLVM IR for code
|
||||
generation.
|
||||
- [Chapter #6](Ch-6.md): A REPL?
- [Chapter #7](Ch-7.md): Custom backends? GPU using LLVM? TPU? XLA?
|
||||
|
||||
## The Language
|
||||
|
||||
This tutorial will be illustrated with a toy language that we’ll call “Toy”
|
||||
(naming is hard...). Toy is an array-based language that allows you to define
|
||||
functions, perform some math computation, and print results.
|
||||
|
||||
Because we want to keep things simple, the codegen will be limited to arrays of
|
||||
rank <= 2 and the only datatype in Toy is a 64-bit floating point type (aka
|
||||
‘double’ in C parlance). As such, all values are implicitly double precision.
|
||||
Values are immutable: every operation returns a newly allocated value, and
|
||||
deallocation is automatically managed. But enough with the long description;
|
||||
nothing is better than walking through an example to get a better understanding:
|
||||
|
||||
FIXME: update/modify matrix multiplication to use @ instead of *
|
||||
|
||||
```Toy {.toy}
|
||||
def main() {
|
||||
# Define a variable `a` with shape <2, 3>, initialized with the literal value.
|
||||
# The shape is inferred from the supplied literal.
|
||||
var a = [[1, 2, 3], [4, 5, 6]];
|
||||
# b is identical to a, the literal array is implicitly reshaped: defining new
|
||||
# variables is the way to reshape arrays (element count must match).
|
||||
var b<2, 3> = [1, 2, 3, 4, 5, 6];
|
||||
# transpose() and print() are the only builtins; the following will transpose
|
||||
# b and perform a matrix multiplication before printing the result.
|
||||
print(a * transpose(b));
|
||||
}
|
||||
```
|
||||
|
||||
Type checking is statically performed through type inference; the language only
requires type declarations to specify array shapes when needed. Functions are
generic: their parameters are unranked (in other words, we know these are arrays
but we don't know their number of dimensions or the size of each dimension).
They are specialized for every newly discovered signature at call sites. Let's
revisit the previous example by adding a user-defined function:
|
||||
|
||||
```Toy {.toy}
|
||||
# User defined generic function that operates on unknown shaped arguments
|
||||
def multiply_transpose(a, b) {
|
||||
return a * transpose(b);
|
||||
}
|
||||
|
||||
def main() {
|
||||
# Define a variable `a` with shape <2, 3>, initialized with the literal value.
|
||||
var a = [[1, 2, 3], [4, 5, 6]];
|
||||
var b<2, 3> = [1, 2, 3, 4, 5, 6];
|
||||
# This call will specialize `multiply_transpose` with <2, 3> for both
|
||||
# arguments and deduce a return type of <2, 2> in initialization of `c`.
|
||||
var c = multiply_transpose(a, b);
|
||||
# A second call to `multiply_transpose` with <2, 3> for both arguments will
|
||||
# reuse the previously specialized and inferred version and return `<2, 2>`
|
||||
var d = multiply_transpose(b, a);
|
||||
# A new call with `<2, 2>` for both arguments will trigger another
|
||||
# specialization of `multiply_transpose`.
|
||||
var e = multiply_transpose(c, d);
|
||||
# Finally, calling into `multiply_transpose` with incompatible shapes will
|
||||
# trigger a shape inference error.
|
||||
var e = multiply_transpose(transpose(a), c);
|
||||
}
|
||||
```
|
||||
|
||||
## The AST
|
||||
|
||||
The AST for the above code is fairly straightforward; here is a dump of it:
|
||||
|
||||
```
|
||||
Module:
|
||||
Function
|
||||
Proto 'multiply_transpose' @test/ast.toy:5:1
|
||||
Args: [a, b]
|
||||
Block {
|
||||
Return
|
||||
BinOp: * @test/ast.toy:6:12
|
||||
var: a @test/ast.toy:6:10
|
||||
Call 'transpose' [ @test/ast.toy:6:14
|
||||
var: b @test/ast.toy:6:24
|
||||
]
|
||||
} // Block
|
||||
Function
|
||||
Proto 'main' @test/ast.toy:9:1
|
||||
Args: []
|
||||
Block {
|
||||
VarDecl a<2, 3> @test/ast.toy:11:3
|
||||
Literal: <2, 3>[<3>[1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[4.000000e+00, 5.000000e+00, 6.000000e+00]] @test/ast.toy:11:17
|
||||
VarDecl b<2, 3> @test/ast.toy:12:3
|
||||
Literal: <6>[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @test/ast.toy:12:17
|
||||
VarDecl c<> @test/ast.toy:15:3
|
||||
Call 'multiply_transpose' [ @test/ast.toy:15:11
|
||||
var: a @test/ast.toy:15:30
|
||||
var: b @test/ast.toy:15:33
|
||||
]
|
||||
VarDecl d<> @test/ast.toy:18:3
|
||||
Call 'multiply_transpose' [ @test/ast.toy:18:11
|
||||
var: b @test/ast.toy:18:30
|
||||
var: a @test/ast.toy:18:33
|
||||
]
|
||||
VarDecl e<> @test/ast.toy:21:3
|
||||
Call 'multiply_transpose' [ @test/ast.toy:21:11
|
||||
var: b @test/ast.toy:21:30
|
||||
var: c @test/ast.toy:21:33
|
||||
]
|
||||
VarDecl e<> @test/ast.toy:24:3
|
||||
Call 'multiply_transpose' [ @test/ast.toy:24:11
|
||||
Call 'transpose' [ @test/ast.toy:24:30
|
||||
var: a @test/ast.toy:24:40
|
||||
]
|
||||
var: c @test/ast.toy:24:44
|
||||
]
|
||||
} // Block
|
||||
```
|
||||
|
||||
You can reproduce this result and play with the example in the
`examples/toy/Ch1/` directory; try running
`path/to/BUILD/bin/toyc-ch1 test/ast.toy -emit=ast`.
|
||||
|
||||
The code for the lexer is fairly straightforward; it is all in a single header:
`examples/toy/Ch1/include/toy/Lexer.h`. The parser, a recursive descent parser,
can be found in `examples/toy/Ch1/include/toy/Parser.h`. If you are not familiar
with such a lexer/parser, they are very similar to the LLVM Kaleidoscope
equivalents detailed in the first two chapters of the
[Kaleidoscope Tutorial](https://llvm.org/docs/tutorial/LangImpl02.html#the-abstract-syntax-tree-ast).
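To see how these pieces fit together, here is a minimal sketch of a driver,
modeled on the `toyc.cpp` entry point from this chapter: it feeds an in-memory
buffer through the `LexerBuffer` and `Parser` classes, then dumps the resulting
AST. The `source` string and the `in-memory.toy` filename are only illustrative,
and error handling is reduced to a null check.

```c++
#include "toy/AST.h"
#include "toy/Parser.h"

#include <memory>
#include <string>

int main() {
  // A tiny Toy program held in memory; any [begin, end) character range works.
  std::string source = "def main() {\n"
                       "  var a = [[1, 2], [3, 4]];\n"
                       "  print(a);\n"
                       "}\n";

  // LexerBuffer feeds the base Lexer one line at a time from the buffer; the
  // filename is only used to attach locations to tokens.
  toy::LexerBuffer lexer(source.data(), source.data() + source.size(),
                         "in-memory.toy");

  // The recursive descent parser pulls tokens from the lexer and builds the AST.
  toy::Parser parser(lexer);
  std::unique_ptr<toy::ModuleAST> module = parser.ParseModule();
  if (!module)
    return 1; // A parse error was already reported on stderr.

  // Print the AST in the same format as `toyc-ch1 -emit=ast`.
  toy::dump(*module);
  return 0;
}
```

Running this prints the same `Module:`/`Function`/`Block` structure as the dump
shown above.
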
The [next chapter](Ch-2.md) will demonstrate how to convert this AST into MLIR.
|
|
@ -1,3 +1,8 @@
|
|||
llvm_canonicalize_cmake_booleans(
|
||||
LLVM_BUILD_EXAMPLES
|
||||
)
|
||||
|
||||
|
||||
configure_lit_site_cfg(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
|
||||
|
@ -20,6 +25,13 @@ set(MLIR_TEST_DEPENDS
|
|||
mlir-translate
|
||||
)
|
||||
|
||||
|
||||
if(LLVM_BUILD_EXAMPLES)
|
||||
list(APPEND MLIR_TEST_DEPENDS
|
||||
toyc-ch1
|
||||
)
|
||||
endif()
|
||||
|
||||
add_lit_testsuite(check-mlir "Running the MLIR regression tests"
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${MLIR_TEST_DEPENDS}
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
# RUN: toyc-ch1 %s -emit=ast 2>&1 | FileCheck %s
|
||||
|
||||
|
||||
# User defined generic function that operates solely on unknown shaped arguments
|
||||
def multiply_transpose(a, b) {
|
||||
return a * transpose(b);
|
||||
}
|
||||
|
||||
def main() {
|
||||
# Define a variable `a` with shape <2, 3>, initialized with the literal value
|
||||
var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
|
||||
var b<2, 3> = [1, 2, 3, 4, 5, 6];
|
||||
# This call will specialize `multiply_transpose` with <2, 3> for both
|
||||
# arguments and deduce a return type of <2, 2> in initialization of `c`.
|
||||
var c = multiply_transpose(a, b);
|
||||
# A second call to `multiply_transpose` with <2, 3> for both arguments will
|
||||
# reuse the previously specialized and inferred version and return `<2, 2>`
|
||||
var d = multiply_transpose(b, a);
|
||||
# A new call with `<2, 2>` for both dimension will trigger another
|
||||
# specialization of `multiply_transpose`.
|
||||
var e = multiply_transpose(b, c);
|
||||
# Finally, calling into `multiply_transpose` with incompatible shape will
|
||||
# trigger a shape inference error.
|
||||
var e = multiply_transpose(transpose(a), c);
|
||||
}
|
||||
|
||||
|
||||
# CHECK: Module:
|
||||
# CHECK-NEXT: Function
|
||||
# CHECK-NEXT: Proto 'multiply_transpose'
|
||||
# CHECK-NEXT: Params: [a, b]
|
||||
# CHECK-NEXT: Block {
|
||||
# CHECK-NEXT: Return
|
||||
# CHECK-NEXT: BinOp: *
|
||||
# CHECK-NEXT: var: a
|
||||
# CHECK-NEXT: Call 'transpose' [
|
||||
# CHECK-NEXT: var: b
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: } // Block
|
||||
# CHECK-NEXT: Function
|
||||
# CHECK-NEXT: Proto 'main'
|
||||
# CHECK-NEXT: Params: []
|
||||
# CHECK-NEXT: Block {
|
||||
# CHECK-NEXT: VarDecl a<2, 3>
|
||||
# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]]
|
||||
# CHECK-NEXT: VarDecl b<2, 3>
|
||||
# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]
|
||||
# CHECK-NEXT: VarDecl c<>
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [
|
||||
# CHECK-NEXT: var: a
|
||||
# CHECK-NEXT: var: b
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: VarDecl d<>
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [
|
||||
# CHECK-NEXT: var: b
|
||||
# CHECK-NEXT: var: a
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: VarDecl e<>
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [
|
||||
# CHECK-NEXT: var: b
|
||||
# CHECK-NEXT: var: c
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: VarDecl e<>
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [
|
||||
# CHECK-NEXT: Call 'transpose' [
|
||||
# CHECK-NEXT: var: a
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: var: c
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: } // Block
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
if not config.build_examples:
|
||||
config.unsupported = True
|
|
@ -21,7 +21,7 @@ config.name = 'MLIR'
|
|||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
||||
|
||||
# suffixes: A list of file extensions to treat as test files.
|
||||
config.suffixes = ['.td', '.mlir']
|
||||
config.suffixes = ['.td', '.mlir', '.toy']
|
||||
|
||||
# test_source_root: The root path where tests are located.
|
||||
config.test_source_root = os.path.dirname(__file__)
|
||||
|
@ -54,4 +54,10 @@ tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir]
|
|||
tools = [
|
||||
'mlir-opt', 'mlir-tblgen', 'mlir-translate',
|
||||
]
|
||||
|
||||
# The following tools are optional
|
||||
tools.extend([
|
||||
ToolSubst('toyc-ch1', unresolved='ignore'),
|
||||
])
|
||||
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
|
|
@ -30,6 +30,7 @@ config.host_arch = "@HOST_ARCH@"
|
|||
config.mlir_src_root = "@MLIR_SOURCE_DIR@"
|
||||
config.mlir_obj_root = "@MLIR_BINARY_DIR@"
|
||||
config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
|
||||
config.build_examples = @LLVM_BUILD_EXAMPLES@
|
||||
|
||||
# Support substitution of the tools_dir with user parameters. This is
|
||||
# used when we can't determine the tool dir at configuration time.
|
||||
|
|