Cleanup EDSCs and start a functional auto-generated library of custom Ops

This CL applies the following simplifications to EDSCs:
1. Rename Block to StmtList because an MLIR Block is a different, not yet
supported, notion;
2. Rework Bindable to drop specific storage and just use it as a simple wrapper
around Expr. The only value of Bindable is to force a static cast when used by
the user to bind into the emitter. For all intended purposes, Bindable is just
a lightweight check that an Expr is Unbound. This simplifies usage and reduces
the API footprint. After playing with it for some time, it wasn't worth the API
cognition overhead;
3. Replace makeExprs and makeBindables by makeNewExprs and copyExprs which is
more explicit and less easy to misuse;
4. Add generally useful functionality to MLIREmitter:
  a. expose zero and one for the ubiquitous common lower bounds and step;
  b. add support to create already bound Exprs for all function arguments as
  well as shapes and views for Exprs bound to memrefs.
5. Delete Stmt::operator= and replace by a `Stmt::set` method which is more
explicit.
6. Make Stmt::operator Expr() explicit.
7. Indexed.indices assertions are removed to pave the way for expressing slices
and views as well as to work with 0-D memrefs.

The CL plugs those simplifications with TableGen and allows emitting a full MLIR function for
pointwise add.

This "x.add" op is both type and rank-agnostic (by allowing ArrayRef of Expr
passed to For loops) and opens the door to spinning up a composable library of
existing and custom ops that should automate a lot of the tedious work in
TF/XLA -> MLIR.

Testing needs to be significantly improved but can be done in a separate CL.

PiperOrigin-RevId: 231982325
This commit is contained in:
Nicolas Vasilache 2019-02-01 09:16:31 -08:00 committed by jpienaar
parent 9f22a2391b
commit 0353ef99eb
11 changed files with 392 additions and 381 deletions

View File

@ -345,7 +345,7 @@ PYBIND11_MODULE(pybind, m) {
m.def("Block", [](const py::list &stmts) {
SmallVector<edsc_stmt_t, 8> owning;
return PythonStmt(::Block(makeCStmts(owning, stmts)));
return PythonStmt(::StmtList(makeCStmts(owning, stmts)));
});
m.def("For", [](const py::list &ivs, const py::list &lbs, const py::list &ubs,
const py::list &steps, const py::list &stmts) {

View File

@ -8,9 +8,6 @@ import unittest
import google_mlir.bindings.python.pybind as E
help(E)
class EdscTest(unittest.TestCase):
def testBindables(self):
@ -31,7 +28,7 @@ class EdscTest(unittest.TestCase):
loop = E.For(i, lb, ub, step, [E.Stmt(E.Add(lb, ub))])
str = loop.__str__()
self.assertIn("for($1 = $2 to $3 step $4) {", str)
self.assertIn("$5 = ($2 + $3)", str)
self.assertIn(" = ($2 + $3)", str)
def testTwoLoops(self):
with E.ContextManager():
@ -66,7 +63,7 @@ class EdscTest(unittest.TestCase):
A, B, C = list(map(E.Indexed, [E.Bindable() for _ in range(3)]))
stmt = C.store([i, j], A.load([i, k]) * B.load([k, j]))
str = stmt.__str__()
self.assertIn(" = store( ... )", str)
self.assertIn(" = store(", str)
def testMatmul(self):
with E.ContextManager():
@ -84,7 +81,7 @@ class EdscTest(unittest.TestCase):
self.assertIn("for($1 = $4 to $7 step $10) {", str)
self.assertIn("for($2 = $5 to $8 step $11) {", str)
self.assertIn("for($3 = $6 to $9 step $12) {", str)
self.assertIn(" = store( ... )", str)
self.assertIn(" = store", str)
def testArithmetic(self):
with E.ContextManager():
@ -105,9 +102,9 @@ class EdscTest(unittest.TestCase):
i, j = list(map(E.Expr, [E.Bindable() for _ in range(2)]))
stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)])
str = stmt.__str__()
self.assertIn("block {", str)
self.assertIn("$3 = ($1 + $2)", str)
self.assertIn("$4 = ($1 - $2)", str)
self.assertIn("stmt_list {", str)
self.assertIn(" = ($1 + $2)", str)
self.assertIn(" = ($1 - $2)", str)
self.assertIn("}", str)
def testMLIRScalarTypes(self):

View File

@ -218,7 +218,7 @@ edsc_stmt_t Return(edsc_expr_list_t values);
/// this is pure syntactic sugar to allow lists of mlir::edsc::Stmt to be
/// specified and emitted. In particular, block arguments are not currently
/// supported.
edsc_stmt_t Block(edsc_stmt_list_t enclosedStmts);
edsc_stmt_t StmtList(edsc_stmt_list_t enclosedStmts);
/// Returns an opaque statement for an mlir::ForInst with `enclosedStmts` nested
/// below it.

View File

@ -57,8 +57,7 @@ namespace edsc {
struct MLIREmitter {
using BindingMap = llvm::DenseMap<Expr, Value *>;
explicit MLIREmitter(FuncBuilder *builder, Location location)
: builder(builder), location(location) {}
MLIREmitter(FuncBuilder *builder, Location location);
FuncBuilder *getBuilder() { return builder; }
Location getLocation() { return location; }
@ -99,12 +98,6 @@ struct MLIREmitter {
return *this;
}
/// Emits the MLIR for `expr` and inserts at the `builder`'s insertion point.
/// This function must only be called once on a given emitter.
/// Prerequisites: all the Bindables have been bound.
Value *emit(Expr expr);
llvm::SmallVector<Value *, 8> emit(llvm::ArrayRef<Expr> exprs);
/// Emits the MLIR for `stmt` and inserts at the `builder`'s insertion point.
/// Prerequisites: all the Bindables have been bound.
void emitStmt(const Stmt &stmt);
@ -114,6 +107,15 @@ struct MLIREmitter {
/// Prerequisite: it must exist.
Value *getValue(Expr expr) { return ssaBindings.lookup(expr); }
/// Returns unique zero and one Expr that are bound to the corresponding
/// constant index value.
Expr zero() { return zeroIndex; }
Expr one() { return oneIndex; }
/// Returns a list of Expr that are bound to the function arguments.
SmallVector<edsc::Expr, 8>
makeBoundFunctionArguments(mlir::Function *function);
/// Returns a list of `Bindable` that are bound to the dimensions of the
/// memRef. The proper DimOp and ConstantOp are constructed at the current
/// insertion point in `builder`. They can be later hoisted and simplified in
@ -121,12 +123,55 @@ struct MLIREmitter {
///
/// Prerequisite:
/// `memRef` is a Value of type MemRefType.
SmallVector<edsc::Expr, 8> makeBoundSizes(Value *memRef);
SmallVector<edsc::Expr, 8> makeBoundMemRefShape(Value *memRef);
/// A BoundMemRefView represents the information required to step through a
/// memref. It has placeholders for non-contiguous tensors that fit within the
/// Fortran subarray model.
/// It is extracted from a memref
struct BoundMemRefView {
SmallVector<edsc::Expr, 8> lbs;
SmallVector<edsc::Expr, 8> ubs;
SmallVector<edsc::Expr, 8> steps;
edsc::Expr dim(unsigned idx) const { return ubs[idx]; }
unsigned rank() const { return lbs.size(); }
};
BoundMemRefView makeBoundMemRefView(Value *memRef);
BoundMemRefView makeBoundMemRefView(Expr memRef);
template <typename IterType>
SmallVector<BoundMemRefView, 8> makeBoundMemRefViews(IterType begin,
IterType end) {
static_assert(!std::is_same<decltype(*begin), Indexed>(),
"Must use Bindable or Expr");
SmallVector<mlir::edsc::MLIREmitter::BoundMemRefView, 8> res;
for (auto it = begin; it != end; ++it) {
res.push_back(makeBoundMemRefView(*it));
}
return res;
}
template <typename T>
SmallVector<BoundMemRefView, 8>
makeBoundMemRefViews(std::initializer_list<T> args) {
static_assert(!std::is_same<T, Indexed>(), "Must use Bindable or Expr");
SmallVector<mlir::edsc::MLIREmitter::BoundMemRefView, 8> res;
for (auto m : args) {
res.push_back(makeBoundMemRefView(m));
}
return res;
}
private:
/// Emits the MLIR for `expr` and inserts at the `builder`'s insertion point.
/// This function must only be called once on a given emitter.
/// Prerequisites: all the Bindables have been bound.
Value *emitExpr(Expr expr);
llvm::SmallVector<Value *, 8> emitExprs(llvm::ArrayRef<Expr> exprs);
FuncBuilder *builder;
Location location;
BindingMap ssaBindings;
// These are so ubiquitous that we make them bound and available to all.
Expr zeroIndex, oneIndex;
};
} // namespace edsc

View File

@ -41,7 +41,6 @@ namespace edsc {
namespace detail {
struct ExprStorage;
struct BindableStorage;
struct UnaryExprStorage;
struct BinaryExprStorage;
struct TernaryExprStorage;
@ -112,7 +111,7 @@ enum class ExprKind {
Return,
LAST_VARIADIC_EXPR = Return,
FIRST_STMT_BLOCK_LIKE_EXPR = 600,
Block = FIRST_STMT_BLOCK_LIKE_EXPR,
StmtList = FIRST_STMT_BLOCK_LIKE_EXPR,
For,
LAST_STMT_BLOCK_LIKE_EXPR = For,
LAST_NON_BINDABLE_EXPR = LAST_STMT_BLOCK_LIKE_EXPR,
@ -170,17 +169,14 @@ public:
return allocator;
}
Expr() : storage(nullptr) {}
Expr();
/* implicit */ Expr(ImplType *storage) : storage(storage) {}
explicit Expr(edsc_expr_t expr)
: storage(reinterpret_cast<ImplType *>(expr)) {}
operator edsc_expr_t() { return edsc_expr_t{storage}; }
Expr(const Expr &other) : storage(other.storage) {}
Expr &operator=(Expr other) {
storage = other.storage;
return *this;
}
Expr(const Expr &other) = default;
Expr &operator=(const Expr &other) = default;
explicit operator bool() { return storage; }
bool operator!() { return storage == nullptr; }
@ -193,6 +189,7 @@ public:
/// Returns the classification for this type.
ExprKind getKind() const;
unsigned getId() const;
void print(raw_ostream &os) const;
void dump() const;
@ -218,27 +215,26 @@ public:
friend ::llvm::hash_code hash_value(Expr arg);
protected:
friend struct detail::ExprStorage;
ImplType *storage;
static void resetIds() { newId() = 0; }
static unsigned &newId();
};
struct Bindable : public Expr {
using ImplType = detail::BindableStorage;
friend class Expr;
Bindable();
unsigned getId() const;
// protected:
Bindable(Expr::ImplType *ptr) : Expr(ptr) {
assert(!ptr || isa<Bindable>() && "expected Bindable");
Bindable() = delete;
Bindable(Expr expr) : Expr(expr) {
assert(expr.isa<Bindable>() && "expected Bindable");
}
Bindable(const Bindable &) = default;
Bindable &operator=(const Bindable &) = default;
explicit Bindable(const edsc_expr_t &expr) : Expr(expr) {}
operator edsc_expr_t() { return edsc_expr_t{storage}; }
friend struct ScopedEDSCContext;
private:
static void resetIds() { newId() = 0; }
static unsigned &newId();
friend class Expr;
friend struct ScopedEDSCContext;
};
struct UnaryExpr : public Expr {
@ -301,7 +297,6 @@ struct StmtBlockLikeExpr : public Expr {
StmtBlockLikeExpr(ExprKind kind, llvm::ArrayRef<Expr> exprs,
llvm::ArrayRef<Type> types = {});
llvm::ArrayRef<Expr> getExprs() const;
llvm::ArrayRef<Type> getTypes() const;
protected:
StmtBlockLikeExpr(Expr::ImplType *ptr) : Expr(ptr) {
@ -316,7 +311,7 @@ protected:
///
/// ```mlir
/// Stmt scalarValue, vectorValue, tmpAlloc, tmpDealloc, vectorView;
/// Stmt block = Block({
/// Stmt block = StmtList({
/// tmpAlloc = alloc(tmpMemRefType),
/// vectorView = vector_type_cast(tmpAlloc, vectorMemRefType),
/// For(ivs, lbs, ubs, steps, {
@ -342,36 +337,38 @@ protected:
/// 1. `For`-loops for which the `lhs` binds to the induction variable, `rhs`
/// binds to an Expr of kind `ExprKind::For` with lower-bound, upper-bound and
/// step respectively;
/// 2. `Block` with an Expr of kind `ExprKind::Block` and which has no `rhs` but
/// 2. `StmtList` with an Expr of kind `ExprKind::StmtList` and which has no
/// `rhs` but
/// only `enclosingStmts`.
struct Stmt {
using ImplType = detail::StmtStorage;
friend class Expr;
Stmt() : storage(nullptr) {}
explicit Stmt(ImplType *storage) : storage(storage) {}
Stmt(const Stmt &other) : storage(other.storage) {}
Stmt operator=(const Stmt &other) {
this->storage = other.storage; // NBD if &other == this
return *this;
}
Stmt(const Stmt &other) = default;
Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> stmts = llvm::ArrayRef<Stmt>());
Stmt(const Bindable &lhs, const Expr &rhs,
llvm::ArrayRef<Stmt> stmts = llvm::ArrayRef<Stmt>());
explicit operator Expr() const { return getLHS(); }
Stmt &operator=(const Expr &expr);
Stmt &set(const Stmt &other) {
this->storage = other.storage;
return *this;
}
Stmt &operator=(const Stmt &other) = delete;
explicit Stmt(edsc_stmt_t stmt)
: storage(reinterpret_cast<ImplType *>(stmt)) {}
operator edsc_stmt_t() { return edsc_stmt_t{storage}; }
operator Expr() const { return getLHS(); }
/// For debugging purposes.
const void *getStoragePtr() const { return storage; }
const ImplType *getStoragePtr() const { return storage; }
void print(raw_ostream &os, llvm::Twine indent = "") const;
void dump() const;
std::string str() const;
Bindable getLHS() const;
Expr getLHS() const;
Expr getRHS() const;
llvm::ArrayRef<Stmt> getEnclosedStmts() const;
@ -470,45 +467,41 @@ namespace edsc {
///
/// Since bindings are hashed by the underlying pointer address, we need to be
/// sure to construct new elements in a vector. We cannot just use
/// `llvm::SmallVector<Bindable, 8> dims(n);` directly because a single
/// `Bindable` will be default constructed and copied everywhere in the vector.
/// Hilarity ensues when trying to bind structs that are already bound.
llvm::SmallVector<Bindable, 8> makeBindables(unsigned n);
llvm::SmallVector<Expr, 8> makeExprs(unsigned n);
llvm::SmallVector<Expr, 8> makeExprs(ArrayRef<Bindable> bindables);
/// `llvm::SmallVector<Expr, 8> dims(n);` directly because a single
/// `Expr` will be default constructed and copied everywhere in the vector.
/// Hilarity ensues when trying to bind `Expr` multiple times.
llvm::SmallVector<Expr, 8> makeNewExprs(unsigned n);
template <typename IterTy>
llvm::SmallVector<Expr, 8> makeExprs(IterTy begin, IterTy end) {
llvm::SmallVector<Expr, 8> copyExprs(IterTy begin, IterTy end) {
return llvm::SmallVector<Expr, 8>(begin, end);
}
inline llvm::SmallVector<Expr, 8> copyExprs(llvm::ArrayRef<Expr> exprs) {
return llvm::SmallVector<Expr, 8>(exprs.begin(), exprs.end());
}
Expr alloc(llvm::ArrayRef<Expr> sizes, Type memrefType);
inline Expr alloc(Type memrefType) { return alloc({}, memrefType); }
Expr dealloc(Expr memref);
Expr load(Expr m, Expr index);
Expr load(Expr m, Bindable index);
Expr load(Expr m, llvm::ArrayRef<Expr> indices);
Expr load(Expr m, const llvm::SmallVectorImpl<Bindable> &indices);
Expr store(Expr val, Expr m, Expr index);
Expr store(Expr val, Expr m, Bindable index);
Expr store(Expr val, Expr m, llvm::ArrayRef<Expr> indices);
Expr store(Expr val, Expr m, const llvm::SmallVectorImpl<Bindable> &indices);
Expr load(Expr m, llvm::ArrayRef<Expr> indices = {});
inline Expr load(Stmt m, llvm::ArrayRef<Expr> indices = {}) {
return load(m.getLHS(), indices);
}
Expr store(Expr val, Expr m, llvm::ArrayRef<Expr> indices = {});
inline Expr store(Stmt val, Expr m, llvm::ArrayRef<Expr> indices = {}) {
return store(val.getLHS(), m, indices);
}
Expr select(Expr cond, Expr lhs, Expr rhs);
Expr vector_type_cast(Expr memrefExpr, Type memrefType);
Stmt Return(ArrayRef<Expr> values);
Stmt Block(llvm::ArrayRef<Stmt> stmts);
Stmt Return(ArrayRef<Expr> values = {});
Stmt StmtList(llvm::ArrayRef<Stmt> stmts);
Stmt For(Expr lb, Expr ub, Expr step, llvm::ArrayRef<Stmt> enclosedStmts);
Stmt For(const Bindable &idx, Expr lb, Expr ub, Expr step,
llvm::ArrayRef<Stmt> enclosedStmts);
Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Expr> lbs,
Stmt For(llvm::ArrayRef<Expr> indices, llvm::ArrayRef<Expr> lbs,
llvm::ArrayRef<Expr> ubs, llvm::ArrayRef<Expr> steps,
llvm::ArrayRef<Stmt> enclosedStmts);
Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
llvm::ArrayRef<Bindable> ubs, llvm::ArrayRef<Bindable> steps,
llvm::ArrayRef<Stmt> enclosedStmts);
Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
llvm::ArrayRef<Expr> ubs, llvm::ArrayRef<Bindable> steps,
llvm::ArrayRef<Stmt> enclosedStmts);
/// This helper class exists purely for sugaring purposes and allows writing
/// expressions such as:
@ -520,35 +513,30 @@ Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
/// });
/// ```
struct Indexed {
Indexed(Bindable b) : base(b), indices() {}
Indexed(Expr e) : base(e), indices() {}
/// Returns a new `Indexed`. As a consequence, an Indexed with attached
/// indices can never be reused unless it is captured (e.g. via a Stmt).
/// This is consistent with SSA behavior in MLIR but also allows for some
/// minimal state and sugaring.
Indexed operator[](llvm::ArrayRef<Expr> indices) const;
Indexed operator[](llvm::ArrayRef<Bindable> indices) const;
Indexed operator()(llvm::ArrayRef<Expr> indices = {});
/// Returns a new `Stmt`.
/// Emits a `store` and clears the attached indices.
Stmt operator=(Expr expr); // NOLINT: unconventional-assing-operator
/// Implicit conversion.
/// Emits a `load` and clears indices.
operator Expr() const {
assert(!indices.empty() && "Expected attached indices to Indexed");
return load(base, indices);
}
/// Emits a `load`.
operator Expr() { return load(base, indices); }
/// Operator overloadings.
Expr operator+(Expr e) const { return static_cast<Expr>(*this) + e; }
Expr operator-(Expr e) const { return static_cast<Expr>(*this) - e; }
Expr operator*(Expr e) const { return static_cast<Expr>(*this) * e; }
Expr operator+(Expr e) { return load(base, indices) + e; }
Expr operator-(Expr e) { return load(base, indices) - e; }
Expr operator*(Expr e) { return load(base, indices) * e; }
private:
Expr base;
llvm::SmallVector<Expr, 4> indices;
llvm::SmallVector<Expr, 8> indices;
};
} // namespace edsc

View File

@ -140,6 +140,13 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) {
}
}
mlir::edsc::MLIREmitter::MLIREmitter(FuncBuilder *builder, Location location)
: builder(builder), location(location) {
// Build the ubiquitous zero and one at the top of the function.
bindConstant<ConstantIndexOp>(Bindable(zeroIndex), 0);
bindConstant<ConstantIndexOp>(Bindable(oneIndex), 1);
}
MLIREmitter &mlir::edsc::MLIREmitter::bind(Bindable e, Value *v) {
LLVM_DEBUG(printDefininingStatement(llvm::dbgs() << "\nBinding " << e << " @"
<< e.getStoragePtr() << ": ",
@ -153,7 +160,7 @@ MLIREmitter &mlir::edsc::MLIREmitter::bind(Bindable e, Value *v) {
return *this;
}
Value *mlir::edsc::MLIREmitter::emit(Expr e) {
Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
auto it = ssaBindings.find(e);
if (it != ssaBindings.end()) {
return it->second;
@ -163,12 +170,12 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
Value *res = nullptr;
if (auto un = e.dyn_cast<UnaryExpr>()) {
if (un.getKind() == ExprKind::Dealloc) {
builder->create<DeallocOp>(location, emit(un.getExpr()));
builder->create<DeallocOp>(location, emitExpr(un.getExpr()));
return nullptr;
}
} else if (auto bin = e.dyn_cast<BinaryExpr>()) {
auto *a = emit(bin.getLHS());
auto *b = emit(bin.getRHS());
auto *a = emitExpr(bin.getLHS());
auto *b = emitExpr(bin.getRHS());
if (!a || !b) {
return nullptr;
}
@ -221,9 +228,9 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
if (auto ter = e.dyn_cast<TernaryExpr>()) {
if (ter.getKind() == ExprKind::Select) {
auto *cond = emit(ter.getCond());
auto *lhs = emit(ter.getLHS());
auto *rhs = emit(ter.getRHS());
auto *cond = emitExpr(ter.getCond());
auto *lhs = emitExpr(ter.getLHS());
auto *rhs = emitExpr(ter.getRHS());
if (!cond || !rhs || !lhs) {
return nullptr;
}
@ -233,7 +240,7 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
if (auto nar = e.dyn_cast<VariadicExpr>()) {
if (nar.getKind() == ExprKind::Alloc) {
auto exprs = emit(nar.getExprs());
auto exprs = emitExprs(nar.getExprs());
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
return nullptr;
}
@ -243,26 +250,26 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
builder->create<AllocOp>(location, types[0].cast<MemRefType>(), exprs)
->getResult();
} else if (nar.getKind() == ExprKind::Load) {
auto exprs = emit(nar.getExprs());
auto exprs = emitExprs(nar.getExprs());
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
return nullptr;
}
assert(exprs.size() > 1 && "Expected > 1 expr");
assert(nar.getTypes().empty() && "Expected no type");
assert(!exprs.empty() && "Load requires >= 1 exprs");
assert(nar.getTypes().empty() && "Load expects no type");
SmallVector<Value *, 8> vals(exprs.begin() + 1, exprs.end());
res = builder->create<LoadOp>(location, exprs[0], vals)->getResult();
} else if (nar.getKind() == ExprKind::Store) {
auto exprs = emit(nar.getExprs());
auto exprs = emitExprs(nar.getExprs());
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
return nullptr;
}
assert(exprs.size() > 2 && "Expected > 2 expr");
assert(nar.getTypes().empty() && "Expected no type");
assert(exprs.size() >= 2 && "Store requires >= 2 exprs");
assert(nar.getTypes().empty() && "Store expects no type");
SmallVector<Value *, 8> vals(exprs.begin() + 2, exprs.end());
builder->create<StoreOp>(location, exprs[0], exprs[1], vals);
return nullptr;
} else if (nar.getKind() == ExprKind::VectorTypeCast) {
auto exprs = emit(nar.getExprs());
auto exprs = emitExprs(nar.getExprs());
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
return nullptr;
}
@ -274,7 +281,7 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
types[0].cast<MemRefType>())
->getResult();
} else if (nar.getKind() == ExprKind::Return) {
auto exprs = emit(nar.getExprs());
auto exprs = emitExprs(nar.getExprs());
builder->create<ReturnOp>(location, exprs);
return nullptr; // no Value* produced and this is fine.
}
@ -282,12 +289,11 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
if (auto expr = e.dyn_cast<StmtBlockLikeExpr>()) {
if (expr.getKind() == ExprKind::For) {
auto exprs = emit(expr.getExprs());
auto exprs = emitExprs(expr.getExprs());
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
return nullptr;
}
assert(exprs.size() == 3 && "Expected 3 exprs");
assert(expr.getTypes().empty() && "Expected no type");
auto lb =
exprs[0]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
auto ub =
@ -318,23 +324,24 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
return res;
}
SmallVector<Value *, 8> mlir::edsc::MLIREmitter::emit(ArrayRef<Expr> exprs) {
return mlir::functional::map(
[this](Expr e) {
auto *res = this->emit(e);
SmallVector<Value *, 8>
mlir::edsc::MLIREmitter::emitExprs(ArrayRef<Expr> exprs) {
SmallVector<Value *, 8> res;
res.reserve(exprs.size());
for (auto e : exprs) {
res.push_back(this->emitExpr(e));
LLVM_DEBUG(
printDefininingStatement(llvm::dbgs() << "\nEmitted: ", *res));
printDefininingStatement(llvm::dbgs() << "\nEmitted: ", *res.back()));
}
return res;
},
exprs);
}
void mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) {
auto *block = builder->getBlock();
auto ip = builder->getInsertionPoint();
// Blocks are just a containing abstraction, they do not emit their RHS.
if (stmt.getRHS().getKind() != ExprKind::Block) {
auto *val = emit(stmt.getRHS());
if (stmt.getRHS().getKind() != ExprKind::StmtList) {
auto *val = emitExpr(stmt.getRHS());
if (!val) {
assert((stmt.getRHS().getKind() == ExprKind::Dealloc ||
stmt.getRHS().getKind() == ExprKind::Store ||
@ -342,7 +349,8 @@ void mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) {
"dealloc, store or return expected as the only 0-result ops");
return;
}
bind(stmt.getLHS(), val);
// Force create a bindable from stmt.lhs and bind it.
bind(Bindable(stmt.getLHS()), val);
if (stmt.getRHS().getKind() == ExprKind::For) {
// Step into the loop.
builder->setInsertionPointToStart(
@ -408,10 +416,23 @@ static SmallVector<Value *, 8> getMemRefSizes(FuncBuilder *b, Location loc,
}
SmallVector<edsc::Expr, 8>
mlir::edsc::MLIREmitter::makeBoundSizes(Value *memRef) {
mlir::edsc::MLIREmitter::makeBoundFunctionArguments(mlir::Function *function) {
SmallVector<edsc::Expr, 8> res;
for (unsigned pos = 0, npos = function->getNumArguments(); pos < npos;
++pos) {
auto *arg = function->getArgument(pos);
Expr b;
bind(Bindable(b), arg);
res.push_back(Expr(b));
}
return res;
}
SmallVector<edsc::Expr, 8>
mlir::edsc::MLIREmitter::makeBoundMemRefShape(Value *memRef) {
assert(memRef->getType().isa<MemRefType>() && "Expected a MemRef value");
MemRefType memRefType = memRef->getType().cast<MemRefType>();
auto memRefSizes = edsc::makeBindables(memRefType.getShape().size());
auto memRefSizes = edsc::makeNewExprs(memRefType.getShape().size());
auto memrefSizeValues = getMemRefSizes(getBuilder(), getLocation(), memRef);
assert(memrefSizeValues.size() == memRefSizes.size());
bindZipRange(llvm::zip(memRefSizes, memrefSizeValues));
@ -419,37 +440,71 @@ mlir::edsc::MLIREmitter::makeBoundSizes(Value *memRef) {
return res;
}
mlir::edsc::MLIREmitter::BoundMemRefView
mlir::edsc::MLIREmitter::makeBoundMemRefView(Value *memRef) {
auto memRefType = memRef->getType().cast<mlir::MemRefType>();
auto rank = memRefType.getRank();
SmallVector<edsc::Expr, 8> lbs;
lbs.reserve(rank);
Expr zero;
bindConstant<mlir::ConstantIndexOp>(Bindable(zero), 0);
for (unsigned i = 0; i < rank; ++i) {
lbs.push_back(zero);
}
auto ubs = makeBoundMemRefShape(memRef);
SmallVector<edsc::Expr, 8> steps;
lbs.reserve(rank);
Expr one;
bindConstant<mlir::ConstantIndexOp>(Bindable(one), 1);
for (unsigned i = 0; i < rank; ++i) {
steps.push_back(one);
}
return BoundMemRefView{lbs, ubs, steps};
}
mlir::edsc::MLIREmitter::BoundMemRefView
mlir::edsc::MLIREmitter::makeBoundMemRefView(Expr boundMemRef) {
auto *v = getValue(mlir::edsc::Expr(boundMemRef));
assert(v && "Expected a bound Expr");
return makeBoundMemRefView(v);
}
edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Bindable b;
e->bindConstant<mlir::ConstantFloatOp>(b, mlir::APFloat(value),
Expr b;
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value),
e->getBuilder()->getBF16Type());
return b;
}
edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Bindable b;
Expr b;
bool unused;
mlir::APFloat val(value);
val.convert(e->getBuilder()->getF16Type().getFloatSemantics(),
mlir::APFloat::rmNearestTiesToEven, &unused);
e->bindConstant<mlir::ConstantFloatOp>(b, val, e->getBuilder()->getF16Type());
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), val,
e->getBuilder()->getF16Type());
return b;
}
edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Bindable b;
e->bindConstant<mlir::ConstantFloatOp>(b, mlir::APFloat(value),
Expr b;
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value),
e->getBuilder()->getF32Type());
return b;
}
edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Bindable b;
e->bindConstant<mlir::ConstantFloatOp>(b, mlir::APFloat(value),
Expr b;
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value),
e->getBuilder()->getF64Type());
return b;
}
@ -457,7 +512,7 @@ edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value) {
edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value,
unsigned bitwidth) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Bindable b;
Expr b;
e->bindConstant<mlir::ConstantIntOp>(
b, value, e->getBuilder()->getIntegerType(bitwidth));
return b;
@ -465,8 +520,8 @@ edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value,
edsc_expr_t bindConstantIndex(edsc_mlir_emitter_t emitter, int64_t value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Bindable b;
e->bindConstant<mlir::ConstantIndexOp>(b, value);
Expr b;
e->bindConstant<mlir::ConstantIndexOp>(Bindable(b), value);
return b;
}
@ -493,8 +548,8 @@ edsc_expr_t bindFunctionArgument(edsc_mlir_emitter_t emitter,
auto *f = reinterpret_cast<mlir::Function *>(function);
assert(pos < f->getNumArguments());
auto *arg = *(f->getArguments().begin() + pos);
Bindable b;
e->bind(b, arg);
Expr b;
e->bind(Bindable(b), arg);
return Expr(b);
}
@ -505,8 +560,8 @@ void bindFunctionArguments(edsc_mlir_emitter_t emitter, mlir_func_t function,
assert(result->n == f->getNumArguments());
for (unsigned pos = 0; pos < result->n; ++pos) {
auto *arg = *(f->getArguments().begin() + pos);
Bindable b;
e->bind(b, arg);
Expr b;
e->bind(Bindable(b), arg);
result->exprs[pos] = Expr(b);
}
}
@ -515,6 +570,7 @@ unsigned getBoundMemRefRank(edsc_mlir_emitter_t emitter,
edsc_expr_t boundMemRef) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
auto *v = e->getValue(mlir::edsc::Expr(boundMemRef));
assert(v && "Expected a bound Expr");
auto memRefType = v->getType().cast<mlir::MemRefType>();
return memRefType.getRank();
}
@ -523,10 +579,11 @@ void bindMemRefShape(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
edsc_expr_list_t *result) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
auto *v = e->getValue(mlir::edsc::Expr(boundMemRef));
assert(v && "Expected a bound Expr");
auto memRefType = v->getType().cast<mlir::MemRefType>();
auto rank = memRefType.getRank();
assert(result->n == rank && "Unexpected memref shape binding results count");
auto bindables = e->makeBoundSizes(v);
auto bindables = e->makeBoundMemRefShape(v);
for (unsigned i = 0; i < rank; ++i) {
result->exprs[i] = bindables[i];
}
@ -542,14 +599,14 @@ void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
assert(resultLbs->n == rank && "Unexpected memref binding results count");
assert(resultUbs->n == rank && "Unexpected memref binding results count");
assert(resultSteps->n == rank && "Unexpected memref binding results count");
auto bindables = e->makeBoundSizes(v);
for (unsigned i = 0; i < rank; ++i) {
Bindable zero;
auto bindables = e->makeBoundMemRefShape(v);
Expr zero;
e->bindConstant<mlir::ConstantIndexOp>(zero, 0);
Expr one;
e->bindConstant<mlir::ConstantIndexOp>(one, 1);
for (unsigned i = 0; i < rank; ++i) {
resultLbs->exprs[i] = zero;
resultUbs->exprs[i] = bindables[i];
Bindable one;
e->bindConstant<mlir::ConstantIndexOp>(one, 1);
resultSteps->exprs[i] = one;
}
}

View File

@ -39,12 +39,9 @@ namespace edsc {
namespace detail {
struct ExprStorage {
ExprStorage(ExprKind kind) : kind(kind) {}
ExprStorage(ExprKind kind, unsigned id = Expr::newId())
: kind(kind), id(id) {}
ExprKind kind;
};
struct BindableStorage : public ExprStorage {
BindableStorage(unsigned id) : ExprStorage(ExprKind::Unbound), id(id) {}
unsigned id;
};
@ -94,8 +91,23 @@ mlir::edsc::ScopedEDSCContext::~ScopedEDSCContext() {
Expr::globalAllocator() = nullptr;
}
mlir::edsc::Expr::Expr() {
// Initialize with placement new.
storage = Expr::globalAllocator()->Allocate<detail::ExprStorage>();
new (storage) detail::ExprStorage(ExprKind::Unbound);
}
ExprKind mlir::edsc::Expr::getKind() const { return storage->kind; }
unsigned mlir::edsc::Expr::getId() const {
return static_cast<ImplType *>(storage)->id;
}
unsigned &mlir::edsc::Expr::newId() {
static thread_local unsigned id = 0;
return ++id;
}
Expr mlir::edsc::Expr::operator+(Expr other) const {
return BinaryExpr(ExprKind::Add, *this, other);
}
@ -132,25 +144,7 @@ Expr mlir::edsc::Expr::operator||(Expr other) const {
}
// Free functions.
llvm::SmallVector<Bindable, 8> mlir::edsc::makeBindables(unsigned n) {
llvm::SmallVector<Bindable, 8> res;
res.reserve(n);
for (auto i = 0; i < n; ++i) {
res.push_back(Bindable());
}
return res;
}
static llvm::SmallVector<Bindable, 8> makeBindables(edsc_expr_list_t exprList) {
llvm::SmallVector<Bindable, 8> exprs;
exprs.reserve(exprList.n);
for (unsigned i = 0; i < exprList.n; ++i) {
exprs.push_back(Expr(exprList.exprs[i]).cast<Bindable>());
}
return exprs;
}
llvm::SmallVector<Expr, 8> mlir::edsc::makeExprs(unsigned n) {
llvm::SmallVector<Expr, 8> mlir::edsc::makeNewExprs(unsigned n) {
llvm::SmallVector<Expr, 8> res;
res.reserve(n);
for (auto i = 0; i < n; ++i) {
@ -159,15 +153,6 @@ llvm::SmallVector<Expr, 8> mlir::edsc::makeExprs(unsigned n) {
return res;
}
llvm::SmallVector<Expr, 8> mlir::edsc::makeExprs(ArrayRef<Bindable> bindables) {
llvm::SmallVector<Expr, 8> res;
res.reserve(bindables.size());
for (auto b : bindables) {
res.push_back(b);
}
return res;
}
static llvm::SmallVector<Expr, 8> makeExprs(edsc_expr_list_t exprList) {
llvm::SmallVector<Expr, 8> exprs;
exprs.reserve(exprList.n);
@ -177,25 +162,26 @@ static llvm::SmallVector<Expr, 8> makeExprs(edsc_expr_list_t exprList) {
return exprs;
}
static llvm::SmallVector<Stmt, 8> makeStmts(edsc_stmt_list_t enclosedStmts) {
llvm::SmallVector<Stmt, 8> stmts;
stmts.reserve(enclosedStmts.n);
static void fillStmts(edsc_stmt_list_t enclosedStmts,
llvm::SmallVector<Stmt, 8> *stmts) {
stmts->reserve(enclosedStmts.n);
for (unsigned i = 0; i < enclosedStmts.n; ++i) {
stmts.push_back(Stmt(enclosedStmts.stmts[i]));
stmts->push_back(Stmt(enclosedStmts.stmts[i]));
}
return stmts;
}
Expr mlir::edsc::alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) {
return VariadicExpr(ExprKind::Alloc, sizes, memrefType);
}
Stmt mlir::edsc::Block(ArrayRef<Stmt> stmts) {
return Stmt(StmtBlockLikeExpr(ExprKind::Block, {}), stmts);
Stmt mlir::edsc::StmtList(ArrayRef<Stmt> stmts) {
return Stmt(StmtBlockLikeExpr(ExprKind::StmtList, {}), stmts);
}
edsc_stmt_t Block(edsc_stmt_list_t enclosedStmts) {
return Stmt(mlir::edsc::Block(makeStmts(enclosedStmts)));
edsc_stmt_t StmtList(edsc_stmt_list_t enclosedStmts) {
llvm::SmallVector<Stmt, 8> stmts;
fillStmts(enclosedStmts, &stmts);
return Stmt(mlir::edsc::StmtList(stmts));
}
Expr mlir::edsc::dealloc(Expr memref) {
@ -203,8 +189,8 @@ Expr mlir::edsc::dealloc(Expr memref) {
}
Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) {
Bindable idx;
return For(idx, lb, ub, step, stmts);
Expr idx;
return For(Bindable(idx), lb, ub, step, stmts);
}
Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step,
@ -212,105 +198,68 @@ Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step,
return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, {lb, ub, step}), stmts);
}
Stmt mlir::edsc::For(MutableArrayRef<Bindable> indices, ArrayRef<Expr> lbs,
Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs,
ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
ArrayRef<Stmt> enclosedStmts) {
assert(!indices.empty());
assert(indices.size() == lbs.size());
assert(indices.size() == ubs.size());
assert(indices.size() == steps.size());
Expr iv = indices.back();
Stmt curStmt =
For(indices.back(), lbs.back(), ubs.back(), steps.back(), enclosedStmts);
For(Bindable(iv), lbs.back(), ubs.back(), steps.back(), enclosedStmts);
for (int64_t i = indices.size() - 2; i >= 0; --i) {
curStmt = For(indices[i], lbs[i], ubs[i], steps[i], {curStmt});
Expr iiv = indices[i];
curStmt.set(For(Bindable(iiv), lbs[i], ubs[i], steps[i],
llvm::ArrayRef<Stmt>{&curStmt, 1}));
}
return curStmt;
}
Stmt mlir::edsc::For(llvm::MutableArrayRef<Bindable> indices,
llvm::ArrayRef<Bindable> lbs, llvm::ArrayRef<Bindable> ubs,
llvm::ArrayRef<Bindable> steps,
llvm::ArrayRef<Stmt> enclosedStmts) {
return For(indices, SmallVector<Expr, 8>{lbs.begin(), lbs.end()},
SmallVector<Expr, 8>{ubs.begin(), ubs.end()},
SmallVector<Expr, 8>{steps.begin(), steps.end()}, enclosedStmts);
}
Stmt mlir::edsc::For(llvm::MutableArrayRef<Bindable> indices,
llvm::ArrayRef<Bindable> lbs, llvm::ArrayRef<Expr> ubs,
llvm::ArrayRef<Bindable> steps,
llvm::ArrayRef<Stmt> enclosedStmts) {
return For(indices, SmallVector<Expr, 8>{lbs.begin(), lbs.end()}, ubs,
SmallVector<Expr, 8>{steps.begin(), steps.end()}, enclosedStmts);
}
edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
edsc_expr_t step, edsc_stmt_list_t enclosedStmts) {
return Stmt(For(Expr(iv).cast<Bindable>(), Expr(lb), Expr(ub), Expr(step),
makeStmts(enclosedStmts)));
llvm::SmallVector<Stmt, 8> stmts;
fillStmts(enclosedStmts, &stmts);
return Stmt(
For(Expr(iv).cast<Bindable>(), Expr(lb), Expr(ub), Expr(step), stmts));
}
edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs,
edsc_expr_list_t ubs, edsc_expr_list_t steps,
edsc_stmt_list_t enclosedStmts) {
auto bindables = makeBindables(ivs);
return Stmt(For(bindables, makeExprs(lbs), makeExprs(ubs), makeExprs(steps),
makeStmts(enclosedStmts)));
llvm::SmallVector<Stmt, 8> stmts;
fillStmts(enclosedStmts, &stmts);
return Stmt(For(makeExprs(ivs), makeExprs(lbs), makeExprs(ubs),
makeExprs(steps), stmts));
}
template <typename BindableOrExpr>
static Expr loadBuilder(Expr m, ArrayRef<BindableOrExpr> indices) {
Expr mlir::edsc::load(Expr m, ArrayRef<Expr> indices) {
SmallVector<Expr, 8> exprs;
exprs.push_back(m);
exprs.append(indices.begin(), indices.end());
return VariadicExpr(ExprKind::Load, exprs);
}
Expr mlir::edsc::load(Expr m, Expr index) {
return loadBuilder<Expr>(m, {index});
}
Expr mlir::edsc::load(Expr m, Bindable index) {
return loadBuilder<Bindable>(m, {index});
}
Expr mlir::edsc::load(Expr m, const llvm::SmallVectorImpl<Bindable> &indices) {
return loadBuilder(m, ArrayRef<Bindable>{indices.begin(), indices.end()});
}
Expr mlir::edsc::load(Expr m, ArrayRef<Expr> indices) {
return loadBuilder(m, indices);
}
edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices) {
Indexed i(Expr(indexed.base).cast<Bindable>());
Expr res = i[makeExprs(indices)];
auto exprs = makeExprs(indices);
Expr res = i(exprs);
return res;
}
template <typename BindableOrExpr>
static Expr storeBuilder(Expr val, Expr m, ArrayRef<BindableOrExpr> indices) {
Expr mlir::edsc::store(Expr val, Expr m, ArrayRef<Expr> indices) {
SmallVector<Expr, 8> exprs;
exprs.push_back(val);
exprs.push_back(m);
exprs.append(indices.begin(), indices.end());
return VariadicExpr(ExprKind::Store, exprs);
}
Expr mlir::edsc::store(Expr val, Expr m, Expr index) {
return storeBuilder<Expr>(val, m, {index});
}
Expr mlir::edsc::store(Expr val, Expr m, Bindable index) {
return storeBuilder<Bindable>(val, m, {index});
}
Expr mlir::edsc::store(Expr val, Expr m,
const llvm::SmallVectorImpl<Bindable> &indices) {
return storeBuilder(val, m,
ArrayRef<Bindable>{indices.begin(), indices.end()});
}
Expr mlir::edsc::store(Expr val, Expr m, ArrayRef<Expr> indices) {
return storeBuilder(val, m, indices);
}
edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed,
edsc_expr_list_t indices) {
Indexed i(Expr(indexed.base).cast<Bindable>());
Indexed loc = i[makeExprs(indices)];
auto exprs = makeExprs(indices);
Indexed loc = i(exprs);
return Stmt(loc = Expr(value));
}
@ -384,15 +333,20 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
}
}
} else if (auto nar = this->dyn_cast<VariadicExpr>()) {
auto exprs = nar.getExprs();
switch (nar.getKind()) {
case ExprKind::Load:
os << "load( ... )";
os << "load(" << exprs[0] << "[";
interleaveComma(ArrayRef<Expr>(exprs.begin() + 1, exprs.size() - 1), os);
os << "])";
return;
case ExprKind::Store:
os << "store( ... )";
os << "store(" << exprs[0] << ", " << exprs[1] << "[";
interleaveComma(ArrayRef<Expr>(exprs.begin() + 2, exprs.size() - 2), os);
os << "])";
return;
case ExprKind::Return:
interleaveComma(nar.getExprs(), os);
interleaveComma(exprs, os);
return;
default: {
os << "unknown_variadic";
@ -430,22 +384,7 @@ llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
return os;
}
mlir::edsc::Bindable::Bindable()
: Expr(Expr::globalAllocator()->Allocate<detail::BindableStorage>()) {
// Initialize with placement new.
new (storage) detail::BindableStorage{Bindable::newId()};
}
edsc_expr_t makeBindable() { return Bindable(); }
unsigned mlir::edsc::Bindable::getId() const {
return static_cast<ImplType *>(storage)->id;
}
unsigned &mlir::edsc::Bindable::newId() {
static thread_local unsigned id = 0;
return ++id;
}
edsc_expr_t makeBindable() { return Bindable(Expr()); }
mlir::edsc::UnaryExpr::UnaryExpr(ExprKind kind, Expr expr)
: Expr(Expr::globalAllocator()->Allocate<detail::UnaryExprStorage>()) {
@ -519,9 +458,6 @@ mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind,
ArrayRef<Expr> mlir::edsc::StmtBlockLikeExpr::getExprs() const {
return static_cast<ImplType *>(storage)->exprs;
}
ArrayRef<Type> mlir::edsc::StmtBlockLikeExpr::getTypes() const {
return static_cast<ImplType *>(storage)->types;
}
mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
llvm::ArrayRef<Stmt> enclosedStmts) {
@ -536,7 +472,7 @@ mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
}
mlir::edsc::Stmt::Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> enclosedStmts)
: Stmt(Bindable(), rhs, enclosedStmts) {}
: Stmt(Bindable(Expr()), rhs, enclosedStmts) {}
edsc_stmt_t makeStmt(edsc_expr_t e) {
assert(e && "unexpected empty expression");
@ -544,12 +480,12 @@ edsc_stmt_t makeStmt(edsc_expr_t e) {
}
Stmt &mlir::edsc::Stmt::operator=(const Expr &expr) {
Stmt res(Bindable(), expr, {});
Stmt res(Bindable(Expr()), expr, {});
std::swap(res.storage, this->storage);
return *this;
}
Bindable mlir::edsc::Stmt::getLHS() const {
Expr mlir::edsc::Stmt::getLHS() const {
return static_cast<ImplType *>(storage)->lhs;
}
@ -562,7 +498,10 @@ llvm::ArrayRef<Stmt> mlir::edsc::Stmt::getEnclosedStmts() const {
}
void mlir::edsc::Stmt::print(raw_ostream &os, Twine indent) const {
assert(storage && "Unexpected null storage,stmt must be bound to print");
if (!storage) {
os << "null_storage";
return;
}
auto lhs = getLHS();
auto rhs = getRHS();
@ -580,8 +519,8 @@ void mlir::edsc::Stmt::print(raw_ostream &os, Twine indent) const {
}
os << indent << "}";
return;
case ExprKind::Block:
os << indent << "block {";
case ExprKind::StmtList:
os << indent << "stmt_list {";
for (auto &s : getEnclosedStmts()) {
os << "\n";
s.print(os, indent + " ");
@ -613,24 +552,15 @@ llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
return os;
}
Indexed mlir::edsc::Indexed::operator[](llvm::ArrayRef<Expr> indices) const {
Indexed mlir::edsc::Indexed::operator()(llvm::ArrayRef<Expr> indices) {
Indexed res(base);
res.indices = llvm::SmallVector<Expr, 4>(indices.begin(), indices.end());
return res;
}
Indexed mlir::edsc::Indexed::
operator[](llvm::ArrayRef<Bindable> indices) const {
return (*this)[llvm::ArrayRef<Expr>{indices.begin(), indices.end()}];
}
// NOLINTNEXTLINE: unconventional-assign-operator
Stmt mlir::edsc::Indexed::operator=(Expr expr) {
assert(!indices.empty() && "Expected attached indices to Indexed");
assert(base);
Stmt stmt(store(expr, base, indices));
indices.clear();
return stmt;
return Stmt(store(expr, base, indices));
}
edsc_indexed_t makeIndexed(edsc_expr_t expr) {

View File

@ -68,7 +68,7 @@ namespace {
/// local storage.
struct VectorTransferAccessInfo {
// `ivs` are bound for `For` Stmt at `For` Stmt construction time.
llvm::SmallVector<edsc::Bindable, 8> ivs;
llvm::SmallVector<edsc::Expr, 8> ivs;
llvm::SmallVector<edsc::Expr, 8> lowerBoundsExprs;
llvm::SmallVector<edsc::Expr, 8> upperBoundsExprs;
llvm::SmallVector<edsc::Expr, 8> stepExprs;
@ -107,15 +107,15 @@ private:
/// buffer.
MemRefType vectorMemRefType;
// EDSC `emitter` and Bindables that are pre-bound at construction time.
// vectorSizes are bound to the actual constant sizes of vectorType.
llvm::SmallVector<edsc::Bindable, 8> vectorSizes;
// accesses are bound to transfer->getIndices()
llvm::SmallVector<edsc::Bindable, 8> accesses;
// `zero` and `one` are bound to locally scoped constants.
// `scalarMemRef` is bound to `transfer->getMemRef()`.
edsc::Bindable zero, one, scalarMemRef;
// EDSC `emitter` and Exprs that are pre-bound at construction time.
edsc::MLIREmitter emitter;
// vectorSizes are bound to the actual constant sizes of vectorType.
llvm::SmallVector<edsc::Expr, 8> vectorSizes;
// accesses are bound to transfer->getIndices()
llvm::SmallVector<edsc::Expr, 8> accesses;
// `zero` and `one` are bound emitter.zero() and emitter.one().
// `scalarMemRef` is bound to `transfer->getMemRef()`.
edsc::Expr zero, one, scalarMemRef;
};
} // end anonymous namespace
@ -164,19 +164,19 @@ VectorTransferAccessInfo
VectorTransferRewriter<VectorTransferOpTy>::makeVectorTransferAccessInfo() {
using namespace mlir::edsc;
// Create Bindable objects for ivs, they will be bound at `For` Stmt
// Create new Exprs for ivs, they will be bound at `For` Stmt
// construction.
auto ivs = makeBindables(vectorShape.size());
auto ivs = makeNewExprs(vectorShape.size());
// Create and bind Bindables to refer to the Value for memref sizes.
auto memRefSizes = emitter.makeBoundSizes(transfer->getMemRef());
// Create and bind Exprs to refer to the Value for memref sizes.
auto memRefSizes = emitter.makeBoundMemRefShape(transfer->getMemRef());
// Create the edsc::Expr for the clipped and transposes access expressions
// using the permutationMap. Additionally, capture the index accessing the
// most minor dimension.
int coalescingIndex = -1;
auto clippedScalarAccessExprs = makeExprs(accesses);
auto tmpAccessExprs = makeExprs(ivs);
auto clippedScalarAccessExprs = copyExprs(accesses);
auto tmpAccessExprs = copyExprs(ivs);
llvm::DenseSet<unsigned> clipped;
for (auto it : llvm::enumerate(permutationMap.getResults())) {
if (auto affineExpr = it.value().template dyn_cast<AffineDimExpr>()) {
@ -221,9 +221,9 @@ VectorTransferRewriter<VectorTransferOpTy>::makeVectorTransferAccessInfo() {
// Create the proper bindables for lbs, ubs and steps. Additionally, if we
// recorded a coalescing index, permute the loop informations.
auto lbs = makeBindables(ivs.size());
auto ubs = makeExprs(vectorSizes);
auto steps = makeBindables(ivs.size());
auto lbs = makeNewExprs(ivs.size());
auto ubs = copyExprs(vectorSizes);
auto steps = makeNewExprs(ivs.size());
if (coalescingIndex >= 0) {
std::swap(ivs[coalescingIndex], ivs.back());
std::swap(lbs[coalescingIndex], lbs.back());
@ -237,9 +237,9 @@ VectorTransferRewriter<VectorTransferOpTy>::makeVectorTransferAccessInfo() {
llvm::zip(steps, SmallVector<int64_t, 8>(ivs.size(), 1)));
return VectorTransferAccessInfo{ivs,
makeExprs(lbs),
copyExprs(lbs),
ubs,
makeExprs(steps),
copyExprs(steps),
clippedScalarAccessExprs,
tmpAccessExprs};
}
@ -255,14 +255,13 @@ VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
tmpMemRefType(
MemRefType::get(vectorShape, vectorType.getElementType(), {}, 0)),
vectorMemRefType(MemRefType::get({1}, vectorType, {}, 0)),
vectorSizes(edsc::makeBindables(vectorShape.size())),
emitter(edsc::MLIREmitter(rewriter->getBuilder(), transfer->getLoc())) {
emitter(edsc::MLIREmitter(rewriter->getBuilder(), transfer->getLoc())),
vectorSizes(edsc::makeNewExprs(vectorShape.size())), zero(emitter.zero()),
one(emitter.one()) {
// Bind the Bindable.
SmallVector<Value *, 8> transferIndices(transfer->getIndices());
accesses = edsc::makeBindables(transferIndices.size());
emitter.bind(scalarMemRef, transfer->getMemRef())
.template bindConstant<ConstantIndexOp>(zero, 0)
.template bindConstant<ConstantIndexOp>(one, 1)
accesses = edsc::makeNewExprs(transferIndices.size());
emitter.bind(edsc::Bindable(scalarMemRef), transfer->getMemRef())
.template bindZipRangeConstants<ConstantIndexOp>(
llvm::zip(vectorSizes, vectorShape))
.template bindZipRange(llvm::zip(accesses, transfer->getIndices()));
@ -321,23 +320,24 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
auto &lbs = accessInfo.lowerBoundsExprs;
auto &ubs = accessInfo.upperBoundsExprs;
auto &steps = accessInfo.stepExprs;
Stmt scalarValue, vectorValue, tmpAlloc, tmpDealloc, vectorView;
Stmt block = edsc::Block({
Expr scalarValue, vectorValue, tmpAlloc, tmpDealloc, vectorView;
Stmt block = edsc::StmtList({
tmpAlloc = alloc(tmpMemRefType),
vectorView = vector_type_cast(tmpAlloc, vectorMemRefType),
vectorView = vector_type_cast(Expr(tmpAlloc), vectorMemRefType),
For(ivs, lbs, ubs, steps, {
scalarValue = load(scalarMemRef, accessInfo.clippedScalarAccessExprs),
store(scalarValue, tmpAlloc, accessInfo.tmpAccessExprs),
}),
vectorValue = load(vectorView, {zero}),
tmpDealloc = dealloc(tmpAlloc.getLHS())});
tmpDealloc = dealloc(tmpAlloc)
});
// clang-format on
// Emit the MLIR.
emitter.emitStmt(block);
// Finalize rewriting.
transfer->replaceAllUsesWith(emitter.getValue(vectorValue.getLHS()));
transfer->replaceAllUsesWith(emitter.getValue(vectorValue));
transfer->erase();
}
@ -367,24 +367,24 @@ template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
auto accessInfo = makeVectorTransferAccessInfo();
// Bind vector value for the vector_transfer_write.
Bindable vectorValue;
emitter.bind(vectorValue, transfer->getVector());
Expr vectorValue;
emitter.bind(Bindable(vectorValue), transfer->getVector());
// clang-format off
auto &ivs = accessInfo.ivs;
auto &lbs = accessInfo.lowerBoundsExprs;
auto &ubs = accessInfo.upperBoundsExprs;
auto &steps = accessInfo.stepExprs;
Stmt scalarValue, tmpAlloc, tmpDealloc, vectorView;
Stmt block = edsc::Block({
Expr scalarValue, tmpAlloc, tmpDealloc, vectorView;
Stmt block(edsc::StmtList({
tmpAlloc = alloc(tmpMemRefType),
vectorView = vector_type_cast(tmpAlloc, vectorMemRefType),
store(vectorValue, vectorView, {zero}),
store(vectorValue, vectorView, MutableArrayRef<Expr>{zero}),
For(ivs, lbs, ubs, steps, {
scalarValue = load(tmpAlloc, accessInfo.tmpAccessExprs),
store(scalarValue, scalarMemRef, accessInfo.clippedScalarAccessExprs),
}),
tmpDealloc = dealloc(tmpAlloc.getLHS())});
tmpDealloc = dealloc(tmpAlloc)}));
// clang-format on
// Emit the MLIR.

View File

@ -1,29 +1,28 @@
// RUN: mlir-opt -lower-edsc-test %s | FileCheck %s
func @t1(%lhs: memref<3x4x5x6xf32>, %rhs: memref<3x4x5x6xf32>, %result: memref<3x4x5x6xf32>) -> () { return }
func @t1(%lhs: memref<3x4x5x6xvector<4xi8>>, %rhs: memref<3x4x5x6xvector<4xi8>>, %result: memref<3x4x5x6xvector<4xi8>>) -> () { return }
func @t2(%lhs: memref<3x4xf32>, %rhs: memref<3x4xf32>, %result: memref<3x4xf32>) -> () { return }
func @t3(%lhs: memref<f32>, %rhs: memref<f32>, %result: memref<f32>) -> () { return }
func @fn() {
"print"() {op: "x.add", fn: @t1: (memref<3x4x5x6xf32>, memref<3x4x5x6xf32>, memref<3x4x5x6xf32>) -> ()} : () -> ()
"print"() {op: "x.add", fn: @t1: (memref<3x4x5x6xvector<4xi8>>, memref<3x4x5x6xvector<4xi8>>, memref<3x4x5x6xvector<4xi8>>) -> ()} : () -> ()
"print"() {op: "x.add", fn: @t2: (memref<3x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()} : () -> ()
"print"() {op: "x.add", fn: @t3: (memref<f32>, memref<f32>, memref<f32>) -> ()} : () -> ()
return
}
// CHECK: block {
// CHECK: for({{.*}}=[[zero1:.*]] to {{.*}} step [[step1:.*]]) {
// CHECK: for({{.*}}=[[zero1]] to {{.*}} step [[step1]]) {
// CHECK: for({{.*}}=[[zero1]] to {{.*}} step [[step1]]) {
// CHECK: for({{.*}}=[[zero1]] to {{.*}} step [[step1]]) {
// CHECK: {{.*}} = store( ... );
// CHECK: };
// CHECK: };
// CHECK: };
// CHECK: }
// CHECK: }
// CHECK: block {
// CHECK-NEXT: for({{.*}}=[[zero1]] to {{.*}} step [[step1]]) {
// CHECK-NEXT: for({{.*}}=[[zero1]] to {{.*}} step [[step1]]) {
// CHECK-NEXT: for({{.*}}=[[zero1]] to {{.*}} step [[step1]]) {
// CHECK-NEXT: {{.*}} = store((load($3[{{.*}}, {{.*}}, {{.*}}, {{.*}}]) + load($4[{{.*}}, {{.*}}, {{.*}}, {{.*}}])), $5[{{.*}}, {{.*}}, {{.*}}, {{.*}}])
// CHECK-NEXT: };
// CHECK-NEXT: };
// CHECK-NEXT: };
// CHECK-NEXT: }
// CHECK: for({{.*}}=[[zero2:.*]] to {{.*}} step [[step2:.*]]) {
// CHECK: for({{.*}}=[[zero2]] to {{.*}} step [[step2]]) {
// CHECK: {{.*}} = store( ... );
// CHECK: };
// CHECK: }
// CHECK: }
// CHECK-NEXT: for({{.*}}=[[zero2]] to {{.*}} step [[step2]]) {
// CHECK-NEXT: {{.*}} = store((load($3[{{.*}}, {{.*}}]) + load($4[{{.*}}, {{.*}}])), $5[{{.*}}, {{.*}}])
// CHECK-NEXT: };
// CHECK-NEXT: }
// CHECK: {{.*}} = store((load($3[]) + load($4[])), $5[])

View File

@ -6,23 +6,25 @@ include "mlir/IR/op_base.td"
#endif // OP_BASE
def X_AddOp : Op<"x.add">,
Arguments<(ins Tensor:$lhs, Tensor:$rhs)>,
Results<(outs Tensor)> {
Arguments<(ins Tensor:$A, Tensor:$B)>,
Results<(outs Tensor: $C)> {
// TODO: extract referenceImplementation to Op.
// TODO: shrink the reference implementation
code referenceImplementation = [{
auto ivs = makeBindables(lhsShape.size());
Bindable zero, one;
// Same bindable, all equal to `zero`.
SmallVector<Bindable, 8> zeros(ivs.size(), zero);
// Same bindable, all equal to `one`.
SmallVector<Bindable, 8> ones(ivs.size(), one);
Indexed IA(lhs), IB(rhs), IC(result);
block = edsc::Block({
For(ivs, zeros, lhsShape, ones, {
IC[ivs] = IA[ivs] + IB[ivs]
})
});
auto ivs = makeNewExprs(view_A.rank());
// TODO(jpienaar@): automate the positional/named extraction. Need to be a
// bit careful about things memref (from which a "view" can be extracted)
// and the rest (see ReferenceImplGen.cpp).
Indexed A = args[0], B = args[1], C = args[2];
if (ivs.size() > 0) {
block.set(
For(ivs, view_A.lbs, view_A.ubs, view_A.steps, {
C(ivs) = A(ivs) + B(ivs)
}));
} else {
// 0-D case is always important to get right for composability.
block.set(C() = A() + B());
}
}];
}

View File

@ -51,42 +51,35 @@ static void emitReferenceImplementations(const RecordKeeper &recordKeeper,
if (!ref)
continue;
os << " else if (opName == \"" << op.getOperationName() << "\") {\n"
<< " edsc::MLIREmitter emitter(&builder, f->getLoc());\n";
<< " edsc::ScopedEDSCContext raiiContext;\n"
<< " Stmt block;\n"
<< " edsc::MLIREmitter emitter(&builder, f->getLoc());\n"
<< " auto zero = emitter.zero(); (void)zero;\n"
<< " auto one = emitter.one(); (void)one;\n"
<< " auto args = emitter.makeBoundFunctionArguments(f);\n"
// TODO(jpienaar): this is generally incorrect, not all args are memref
// in the general case.
<< " auto views = emitter.makeBoundMemRefViews(args.begin(), "
"args.end());\n";
// Create memrefs for the operands. Operand $x has variable name xMemRef.
for (auto arg : op.getOperands()) {
if (arg.name.empty())
PrintFatalError(def->getLoc(), "all operands must be named");
os << formatv(" mlir::BlockArgument* {0}MemRef;\n", arg.name);
}
os << " mlir::BlockArgument* resultMemRef;\n";
os << " {\n auto opIt = f->getArguments().begin();\n";
for (auto arg : op.getOperands()) {
os.indent(4) << arg.name << "MemRef = *opIt++;\n";
}
os.indent(4) << "resultMemRef = *opIt++;\n";
os << " }\n";
for (auto arg : op.getOperands()) {
os << formatv(" Bindable {0}; (void){0};\n", arg.name);
}
os << " Bindable result;\n";
for (auto arg : op.getOperands()) {
os.indent(2) << formatv(
"auto {0}Shape = emitter.makeBoundSizes({0}MemRef); "
"(void){0}Shape;\n",
arg.name);
for (auto en : llvm::enumerate(op.getOperands())) {
os.indent(2) << formatv("auto &view_{0} = views[{1}]; "
"(void)view_{0};\n",
en.value().name, en.index());
}
// Print the EDSC.
os << ref->getAsUnquotedString() << "\n}";
os << ref->getAsUnquotedString() << "\n";
os.indent(2) << "block.print(llvm::outs());\n\n";
os.indent(2) << "emitter.emitStmt(block);\n\n";
os.indent(2) << "llvm::outs() << \"\\n\";\n\n";
os << "}";
}
os << " else {"
<< " f->emitError(\"no reference implementation for \" + opName);\n"
<< " return;\n}\n";
os << " block.print(llvm::outs());\n llvm::outs() << \"\\n\";\n"
<< "}\n";
os << " else {\n";
os.indent(2) << "f->emitError(\"no reference impl. for \" + opName);\n";
os.indent(2) << "return;\n";
os << "}\n";
os << "}\n";
}
static mlir::GenRegistration