forked from OSchip/llvm-project
Cleanup EDSCs and start a functional auto-generated library of custom Ops
This CL applies the following simplifications to EDSCs: 1. Rename Block to StmtList because an MLIR Block is a different, not yet supported, notion; 2. Rework Bindable to drop specific storage and just use it as a simple wrapper around Expr. The only value of Bindable is to force a static cast when used by the user to bind into the emitter. For all intended purposes, Bindable is just a lightweight check that an Expr is Unbound. This simplifies usage and reduces the API footprint. After playing with it for some time, it wasn't worth the API cognition overhead; 3. Replace makeExprs and makeBindables by makeNewExprs and copyExprs which is more explicit and less easy to misuse; 4. Add generally useful functionality to MLIREmitter: a. expose zero and one for the ubiquitous common lower bounds and step; b. add support to create already bound Exprs for all function arguments as well as shapes and views for Exprs bound to memrefs. 5. Delete Stmt::operator= and replace by a `Stmt::set` method which is more explicit. 6. Make Stmt::operator Expr() explicit. 7. Indexed.indices assertions are removed to pave the way for expressing slices and views as well as to work with 0-D memrefs. The CL plugs those simplifications with TableGen and allows emitting a full MLIR function for pointwise add. This "x.add" op is both type and rank-agnostic (by allowing ArrayRef of Expr passed to For loops) and opens the door to spinning up a composable library of existing and custom ops that should automate a lot of the tedious work in TF/XLA -> MLIR. Testing needs to be significantly improved but can be done in a separate CL. PiperOrigin-RevId: 231982325
This commit is contained in:
parent
9f22a2391b
commit
0353ef99eb
|
@ -345,7 +345,7 @@ PYBIND11_MODULE(pybind, m) {
|
|||
|
||||
m.def("Block", [](const py::list &stmts) {
|
||||
SmallVector<edsc_stmt_t, 8> owning;
|
||||
return PythonStmt(::Block(makeCStmts(owning, stmts)));
|
||||
return PythonStmt(::StmtList(makeCStmts(owning, stmts)));
|
||||
});
|
||||
m.def("For", [](const py::list &ivs, const py::list &lbs, const py::list &ubs,
|
||||
const py::list &steps, const py::list &stmts) {
|
||||
|
|
|
@ -8,9 +8,6 @@ import unittest
|
|||
|
||||
import google_mlir.bindings.python.pybind as E
|
||||
|
||||
help(E)
|
||||
|
||||
|
||||
class EdscTest(unittest.TestCase):
|
||||
|
||||
def testBindables(self):
|
||||
|
@ -31,7 +28,7 @@ class EdscTest(unittest.TestCase):
|
|||
loop = E.For(i, lb, ub, step, [E.Stmt(E.Add(lb, ub))])
|
||||
str = loop.__str__()
|
||||
self.assertIn("for($1 = $2 to $3 step $4) {", str)
|
||||
self.assertIn("$5 = ($2 + $3)", str)
|
||||
self.assertIn(" = ($2 + $3)", str)
|
||||
|
||||
def testTwoLoops(self):
|
||||
with E.ContextManager():
|
||||
|
@ -66,7 +63,7 @@ class EdscTest(unittest.TestCase):
|
|||
A, B, C = list(map(E.Indexed, [E.Bindable() for _ in range(3)]))
|
||||
stmt = C.store([i, j], A.load([i, k]) * B.load([k, j]))
|
||||
str = stmt.__str__()
|
||||
self.assertIn(" = store( ... )", str)
|
||||
self.assertIn(" = store(", str)
|
||||
|
||||
def testMatmul(self):
|
||||
with E.ContextManager():
|
||||
|
@ -84,7 +81,7 @@ class EdscTest(unittest.TestCase):
|
|||
self.assertIn("for($1 = $4 to $7 step $10) {", str)
|
||||
self.assertIn("for($2 = $5 to $8 step $11) {", str)
|
||||
self.assertIn("for($3 = $6 to $9 step $12) {", str)
|
||||
self.assertIn(" = store( ... )", str)
|
||||
self.assertIn(" = store", str)
|
||||
|
||||
def testArithmetic(self):
|
||||
with E.ContextManager():
|
||||
|
@ -105,9 +102,9 @@ class EdscTest(unittest.TestCase):
|
|||
i, j = list(map(E.Expr, [E.Bindable() for _ in range(2)]))
|
||||
stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)])
|
||||
str = stmt.__str__()
|
||||
self.assertIn("block {", str)
|
||||
self.assertIn("$3 = ($1 + $2)", str)
|
||||
self.assertIn("$4 = ($1 - $2)", str)
|
||||
self.assertIn("stmt_list {", str)
|
||||
self.assertIn(" = ($1 + $2)", str)
|
||||
self.assertIn(" = ($1 - $2)", str)
|
||||
self.assertIn("}", str)
|
||||
|
||||
def testMLIRScalarTypes(self):
|
||||
|
|
|
@ -218,7 +218,7 @@ edsc_stmt_t Return(edsc_expr_list_t values);
|
|||
/// this is pure syntactic sugar to allow lists of mlir::edsc::Stmt to be
|
||||
/// specified and emitted. In particular, block arguments are not currently
|
||||
/// supported.
|
||||
edsc_stmt_t Block(edsc_stmt_list_t enclosedStmts);
|
||||
edsc_stmt_t StmtList(edsc_stmt_list_t enclosedStmts);
|
||||
|
||||
/// Returns an opaque statement for an mlir::ForInst with `enclosedStmts` nested
|
||||
/// below it.
|
||||
|
|
|
@ -57,8 +57,7 @@ namespace edsc {
|
|||
struct MLIREmitter {
|
||||
using BindingMap = llvm::DenseMap<Expr, Value *>;
|
||||
|
||||
explicit MLIREmitter(FuncBuilder *builder, Location location)
|
||||
: builder(builder), location(location) {}
|
||||
MLIREmitter(FuncBuilder *builder, Location location);
|
||||
|
||||
FuncBuilder *getBuilder() { return builder; }
|
||||
Location getLocation() { return location; }
|
||||
|
@ -99,12 +98,6 @@ struct MLIREmitter {
|
|||
return *this;
|
||||
}
|
||||
|
||||
/// Emits the MLIR for `expr` and inserts at the `builder`'s insertion point.
|
||||
/// This function must only be called once on a given emitter.
|
||||
/// Prerequisites: all the Bindables have been bound.
|
||||
Value *emit(Expr expr);
|
||||
llvm::SmallVector<Value *, 8> emit(llvm::ArrayRef<Expr> exprs);
|
||||
|
||||
/// Emits the MLIR for `stmt` and inserts at the `builder`'s insertion point.
|
||||
/// Prerequisites: all the Bindables have been bound.
|
||||
void emitStmt(const Stmt &stmt);
|
||||
|
@ -114,6 +107,15 @@ struct MLIREmitter {
|
|||
/// Prerequisite: it must exist.
|
||||
Value *getValue(Expr expr) { return ssaBindings.lookup(expr); }
|
||||
|
||||
/// Returns unique zero and one Expr that are bound to the corresponding
|
||||
/// constant index value.
|
||||
Expr zero() { return zeroIndex; }
|
||||
Expr one() { return oneIndex; }
|
||||
|
||||
/// Returns a list of Expr that are bound to the function arguments.
|
||||
SmallVector<edsc::Expr, 8>
|
||||
makeBoundFunctionArguments(mlir::Function *function);
|
||||
|
||||
/// Returns a list of `Bindable` that are bound to the dimensions of the
|
||||
/// memRef. The proper DimOp and ConstantOp are constructed at the current
|
||||
/// insertion point in `builder`. They can be later hoisted and simplified in
|
||||
|
@ -121,12 +123,55 @@ struct MLIREmitter {
|
|||
///
|
||||
/// Prerequisite:
|
||||
/// `memRef` is a Value of type MemRefType.
|
||||
SmallVector<edsc::Expr, 8> makeBoundSizes(Value *memRef);
|
||||
SmallVector<edsc::Expr, 8> makeBoundMemRefShape(Value *memRef);
|
||||
|
||||
/// A BoundMemRefView represents the information required to step through a
|
||||
/// memref. It has placeholders for non-contiguous tensors that fit within the
|
||||
/// Fortran subarray model.
|
||||
/// It is extracted from a memref
|
||||
struct BoundMemRefView {
|
||||
SmallVector<edsc::Expr, 8> lbs;
|
||||
SmallVector<edsc::Expr, 8> ubs;
|
||||
SmallVector<edsc::Expr, 8> steps;
|
||||
edsc::Expr dim(unsigned idx) const { return ubs[idx]; }
|
||||
unsigned rank() const { return lbs.size(); }
|
||||
};
|
||||
BoundMemRefView makeBoundMemRefView(Value *memRef);
|
||||
BoundMemRefView makeBoundMemRefView(Expr memRef);
|
||||
template <typename IterType>
|
||||
SmallVector<BoundMemRefView, 8> makeBoundMemRefViews(IterType begin,
|
||||
IterType end) {
|
||||
static_assert(!std::is_same<decltype(*begin), Indexed>(),
|
||||
"Must use Bindable or Expr");
|
||||
SmallVector<mlir::edsc::MLIREmitter::BoundMemRefView, 8> res;
|
||||
for (auto it = begin; it != end; ++it) {
|
||||
res.push_back(makeBoundMemRefView(*it));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
template <typename T>
|
||||
SmallVector<BoundMemRefView, 8>
|
||||
makeBoundMemRefViews(std::initializer_list<T> args) {
|
||||
static_assert(!std::is_same<T, Indexed>(), "Must use Bindable or Expr");
|
||||
SmallVector<mlir::edsc::MLIREmitter::BoundMemRefView, 8> res;
|
||||
for (auto m : args) {
|
||||
res.push_back(makeBoundMemRefView(m));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Emits the MLIR for `expr` and inserts at the `builder`'s insertion point.
|
||||
/// This function must only be called once on a given emitter.
|
||||
/// Prerequisites: all the Bindables have been bound.
|
||||
Value *emitExpr(Expr expr);
|
||||
llvm::SmallVector<Value *, 8> emitExprs(llvm::ArrayRef<Expr> exprs);
|
||||
|
||||
FuncBuilder *builder;
|
||||
Location location;
|
||||
BindingMap ssaBindings;
|
||||
// These are so ubiquitous that we make them bound and available to all.
|
||||
Expr zeroIndex, oneIndex;
|
||||
};
|
||||
|
||||
} // namespace edsc
|
||||
|
|
|
@ -41,7 +41,6 @@ namespace edsc {
|
|||
namespace detail {
|
||||
|
||||
struct ExprStorage;
|
||||
struct BindableStorage;
|
||||
struct UnaryExprStorage;
|
||||
struct BinaryExprStorage;
|
||||
struct TernaryExprStorage;
|
||||
|
@ -112,7 +111,7 @@ enum class ExprKind {
|
|||
Return,
|
||||
LAST_VARIADIC_EXPR = Return,
|
||||
FIRST_STMT_BLOCK_LIKE_EXPR = 600,
|
||||
Block = FIRST_STMT_BLOCK_LIKE_EXPR,
|
||||
StmtList = FIRST_STMT_BLOCK_LIKE_EXPR,
|
||||
For,
|
||||
LAST_STMT_BLOCK_LIKE_EXPR = For,
|
||||
LAST_NON_BINDABLE_EXPR = LAST_STMT_BLOCK_LIKE_EXPR,
|
||||
|
@ -170,17 +169,14 @@ public:
|
|||
return allocator;
|
||||
}
|
||||
|
||||
Expr() : storage(nullptr) {}
|
||||
Expr();
|
||||
/* implicit */ Expr(ImplType *storage) : storage(storage) {}
|
||||
explicit Expr(edsc_expr_t expr)
|
||||
: storage(reinterpret_cast<ImplType *>(expr)) {}
|
||||
operator edsc_expr_t() { return edsc_expr_t{storage}; }
|
||||
|
||||
Expr(const Expr &other) : storage(other.storage) {}
|
||||
Expr &operator=(Expr other) {
|
||||
storage = other.storage;
|
||||
return *this;
|
||||
}
|
||||
Expr(const Expr &other) = default;
|
||||
Expr &operator=(const Expr &other) = default;
|
||||
|
||||
explicit operator bool() { return storage; }
|
||||
bool operator!() { return storage == nullptr; }
|
||||
|
@ -193,6 +189,7 @@ public:
|
|||
|
||||
/// Returns the classification for this type.
|
||||
ExprKind getKind() const;
|
||||
unsigned getId() const;
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
@ -218,27 +215,26 @@ public:
|
|||
friend ::llvm::hash_code hash_value(Expr arg);
|
||||
|
||||
protected:
|
||||
friend struct detail::ExprStorage;
|
||||
ImplType *storage;
|
||||
|
||||
static void resetIds() { newId() = 0; }
|
||||
static unsigned &newId();
|
||||
};
|
||||
|
||||
struct Bindable : public Expr {
|
||||
using ImplType = detail::BindableStorage;
|
||||
friend class Expr;
|
||||
Bindable();
|
||||
unsigned getId() const;
|
||||
|
||||
// protected:
|
||||
Bindable(Expr::ImplType *ptr) : Expr(ptr) {
|
||||
assert(!ptr || isa<Bindable>() && "expected Bindable");
|
||||
Bindable() = delete;
|
||||
Bindable(Expr expr) : Expr(expr) {
|
||||
assert(expr.isa<Bindable>() && "expected Bindable");
|
||||
}
|
||||
Bindable(const Bindable &) = default;
|
||||
Bindable &operator=(const Bindable &) = default;
|
||||
explicit Bindable(const edsc_expr_t &expr) : Expr(expr) {}
|
||||
operator edsc_expr_t() { return edsc_expr_t{storage}; }
|
||||
|
||||
friend struct ScopedEDSCContext;
|
||||
|
||||
private:
|
||||
static void resetIds() { newId() = 0; }
|
||||
static unsigned &newId();
|
||||
friend class Expr;
|
||||
friend struct ScopedEDSCContext;
|
||||
};
|
||||
|
||||
struct UnaryExpr : public Expr {
|
||||
|
@ -301,7 +297,6 @@ struct StmtBlockLikeExpr : public Expr {
|
|||
StmtBlockLikeExpr(ExprKind kind, llvm::ArrayRef<Expr> exprs,
|
||||
llvm::ArrayRef<Type> types = {});
|
||||
llvm::ArrayRef<Expr> getExprs() const;
|
||||
llvm::ArrayRef<Type> getTypes() const;
|
||||
|
||||
protected:
|
||||
StmtBlockLikeExpr(Expr::ImplType *ptr) : Expr(ptr) {
|
||||
|
@ -316,7 +311,7 @@ protected:
|
|||
///
|
||||
/// ```mlir
|
||||
/// Stmt scalarValue, vectorValue, tmpAlloc, tmpDealloc, vectorView;
|
||||
/// Stmt block = Block({
|
||||
/// Stmt block = StmtList({
|
||||
/// tmpAlloc = alloc(tmpMemRefType),
|
||||
/// vectorView = vector_type_cast(tmpAlloc, vectorMemRefType),
|
||||
/// For(ivs, lbs, ubs, steps, {
|
||||
|
@ -342,36 +337,38 @@ protected:
|
|||
/// 1. `For`-loops for which the `lhs` binds to the induction variable, `rhs`
|
||||
/// binds to an Expr of kind `ExprKind::For` with lower-bound, upper-bound and
|
||||
/// step respectively;
|
||||
/// 2. `Block` with an Expr of kind `ExprKind::Block` and which has no `rhs` but
|
||||
/// 2. `StmtList` with an Expr of kind `ExprKind::StmtList` and which has no
|
||||
/// `rhs` but
|
||||
/// only `enclosingStmts`.
|
||||
struct Stmt {
|
||||
using ImplType = detail::StmtStorage;
|
||||
friend class Expr;
|
||||
Stmt() : storage(nullptr) {}
|
||||
explicit Stmt(ImplType *storage) : storage(storage) {}
|
||||
Stmt(const Stmt &other) : storage(other.storage) {}
|
||||
Stmt operator=(const Stmt &other) {
|
||||
this->storage = other.storage; // NBD if &other == this
|
||||
return *this;
|
||||
}
|
||||
Stmt(const Stmt &other) = default;
|
||||
Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> stmts = llvm::ArrayRef<Stmt>());
|
||||
Stmt(const Bindable &lhs, const Expr &rhs,
|
||||
llvm::ArrayRef<Stmt> stmts = llvm::ArrayRef<Stmt>());
|
||||
|
||||
explicit operator Expr() const { return getLHS(); }
|
||||
Stmt &operator=(const Expr &expr);
|
||||
Stmt &set(const Stmt &other) {
|
||||
this->storage = other.storage;
|
||||
return *this;
|
||||
}
|
||||
Stmt &operator=(const Stmt &other) = delete;
|
||||
explicit Stmt(edsc_stmt_t stmt)
|
||||
: storage(reinterpret_cast<ImplType *>(stmt)) {}
|
||||
operator edsc_stmt_t() { return edsc_stmt_t{storage}; }
|
||||
|
||||
operator Expr() const { return getLHS(); }
|
||||
|
||||
/// For debugging purposes.
|
||||
const void *getStoragePtr() const { return storage; }
|
||||
const ImplType *getStoragePtr() const { return storage; }
|
||||
|
||||
void print(raw_ostream &os, llvm::Twine indent = "") const;
|
||||
void dump() const;
|
||||
std::string str() const;
|
||||
|
||||
Bindable getLHS() const;
|
||||
Expr getLHS() const;
|
||||
Expr getRHS() const;
|
||||
llvm::ArrayRef<Stmt> getEnclosedStmts() const;
|
||||
|
||||
|
@ -470,45 +467,41 @@ namespace edsc {
|
|||
///
|
||||
/// Since bindings are hashed by the underlying pointer address, we need to be
|
||||
/// sure to construct new elements in a vector. We cannot just use
|
||||
/// `llvm::SmallVector<Bindable, 8> dims(n);` directly because a single
|
||||
/// `Bindable` will be default constructed and copied everywhere in the vector.
|
||||
/// Hilarity ensues when trying to bind structs that are already bound.
|
||||
llvm::SmallVector<Bindable, 8> makeBindables(unsigned n);
|
||||
llvm::SmallVector<Expr, 8> makeExprs(unsigned n);
|
||||
llvm::SmallVector<Expr, 8> makeExprs(ArrayRef<Bindable> bindables);
|
||||
/// `llvm::SmallVector<Expr, 8> dims(n);` directly because a single
|
||||
/// `Expr` will be default constructed and copied everywhere in the vector.
|
||||
/// Hilarity ensues when trying to bind `Expr` multiple times.
|
||||
llvm::SmallVector<Expr, 8> makeNewExprs(unsigned n);
|
||||
template <typename IterTy>
|
||||
llvm::SmallVector<Expr, 8> makeExprs(IterTy begin, IterTy end) {
|
||||
llvm::SmallVector<Expr, 8> copyExprs(IterTy begin, IterTy end) {
|
||||
return llvm::SmallVector<Expr, 8>(begin, end);
|
||||
}
|
||||
inline llvm::SmallVector<Expr, 8> copyExprs(llvm::ArrayRef<Expr> exprs) {
|
||||
return llvm::SmallVector<Expr, 8>(exprs.begin(), exprs.end());
|
||||
}
|
||||
|
||||
Expr alloc(llvm::ArrayRef<Expr> sizes, Type memrefType);
|
||||
inline Expr alloc(Type memrefType) { return alloc({}, memrefType); }
|
||||
Expr dealloc(Expr memref);
|
||||
Expr load(Expr m, Expr index);
|
||||
Expr load(Expr m, Bindable index);
|
||||
Expr load(Expr m, llvm::ArrayRef<Expr> indices);
|
||||
Expr load(Expr m, const llvm::SmallVectorImpl<Bindable> &indices);
|
||||
Expr store(Expr val, Expr m, Expr index);
|
||||
Expr store(Expr val, Expr m, Bindable index);
|
||||
Expr store(Expr val, Expr m, llvm::ArrayRef<Expr> indices);
|
||||
Expr store(Expr val, Expr m, const llvm::SmallVectorImpl<Bindable> &indices);
|
||||
|
||||
Expr load(Expr m, llvm::ArrayRef<Expr> indices = {});
|
||||
inline Expr load(Stmt m, llvm::ArrayRef<Expr> indices = {}) {
|
||||
return load(m.getLHS(), indices);
|
||||
}
|
||||
Expr store(Expr val, Expr m, llvm::ArrayRef<Expr> indices = {});
|
||||
inline Expr store(Stmt val, Expr m, llvm::ArrayRef<Expr> indices = {}) {
|
||||
return store(val.getLHS(), m, indices);
|
||||
}
|
||||
Expr select(Expr cond, Expr lhs, Expr rhs);
|
||||
Expr vector_type_cast(Expr memrefExpr, Type memrefType);
|
||||
|
||||
Stmt Return(ArrayRef<Expr> values);
|
||||
Stmt Block(llvm::ArrayRef<Stmt> stmts);
|
||||
Stmt Return(ArrayRef<Expr> values = {});
|
||||
Stmt StmtList(llvm::ArrayRef<Stmt> stmts);
|
||||
Stmt For(Expr lb, Expr ub, Expr step, llvm::ArrayRef<Stmt> enclosedStmts);
|
||||
Stmt For(const Bindable &idx, Expr lb, Expr ub, Expr step,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts);
|
||||
Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Expr> lbs,
|
||||
Stmt For(llvm::ArrayRef<Expr> indices, llvm::ArrayRef<Expr> lbs,
|
||||
llvm::ArrayRef<Expr> ubs, llvm::ArrayRef<Expr> steps,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts);
|
||||
Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
|
||||
llvm::ArrayRef<Bindable> ubs, llvm::ArrayRef<Bindable> steps,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts);
|
||||
Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
|
||||
llvm::ArrayRef<Expr> ubs, llvm::ArrayRef<Bindable> steps,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts);
|
||||
|
||||
/// This helper class exists purely for sugaring purposes and allows writing
|
||||
/// expressions such as:
|
||||
|
@ -520,35 +513,30 @@ Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
|
|||
/// });
|
||||
/// ```
|
||||
struct Indexed {
|
||||
Indexed(Bindable b) : base(b), indices() {}
|
||||
Indexed(Expr e) : base(e), indices() {}
|
||||
|
||||
/// Returns a new `Indexed`. As a consequence, an Indexed with attached
|
||||
/// indices can never be reused unless it is captured (e.g. via a Stmt).
|
||||
/// This is consistent with SSA behavior in MLIR but also allows for some
|
||||
/// minimal state and sugaring.
|
||||
Indexed operator[](llvm::ArrayRef<Expr> indices) const;
|
||||
Indexed operator[](llvm::ArrayRef<Bindable> indices) const;
|
||||
Indexed operator()(llvm::ArrayRef<Expr> indices = {});
|
||||
|
||||
/// Returns a new `Stmt`.
|
||||
/// Emits a `store` and clears the attached indices.
|
||||
Stmt operator=(Expr expr); // NOLINT: unconventional-assing-operator
|
||||
|
||||
/// Implicit conversion.
|
||||
/// Emits a `load` and clears indices.
|
||||
operator Expr() const {
|
||||
assert(!indices.empty() && "Expected attached indices to Indexed");
|
||||
return load(base, indices);
|
||||
}
|
||||
/// Emits a `load`.
|
||||
operator Expr() { return load(base, indices); }
|
||||
|
||||
/// Operator overloadings.
|
||||
Expr operator+(Expr e) const { return static_cast<Expr>(*this) + e; }
|
||||
Expr operator-(Expr e) const { return static_cast<Expr>(*this) - e; }
|
||||
Expr operator*(Expr e) const { return static_cast<Expr>(*this) * e; }
|
||||
Expr operator+(Expr e) { return load(base, indices) + e; }
|
||||
Expr operator-(Expr e) { return load(base, indices) - e; }
|
||||
Expr operator*(Expr e) { return load(base, indices) * e; }
|
||||
|
||||
private:
|
||||
Expr base;
|
||||
llvm::SmallVector<Expr, 4> indices;
|
||||
llvm::SmallVector<Expr, 8> indices;
|
||||
};
|
||||
|
||||
} // namespace edsc
|
||||
|
|
|
@ -140,6 +140,13 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) {
|
|||
}
|
||||
}
|
||||
|
||||
mlir::edsc::MLIREmitter::MLIREmitter(FuncBuilder *builder, Location location)
|
||||
: builder(builder), location(location) {
|
||||
// Build the ubiquitous zero and one at the top of the function.
|
||||
bindConstant<ConstantIndexOp>(Bindable(zeroIndex), 0);
|
||||
bindConstant<ConstantIndexOp>(Bindable(oneIndex), 1);
|
||||
}
|
||||
|
||||
MLIREmitter &mlir::edsc::MLIREmitter::bind(Bindable e, Value *v) {
|
||||
LLVM_DEBUG(printDefininingStatement(llvm::dbgs() << "\nBinding " << e << " @"
|
||||
<< e.getStoragePtr() << ": ",
|
||||
|
@ -153,7 +160,7 @@ MLIREmitter &mlir::edsc::MLIREmitter::bind(Bindable e, Value *v) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
Value *mlir::edsc::MLIREmitter::emit(Expr e) {
|
||||
Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
|
||||
auto it = ssaBindings.find(e);
|
||||
if (it != ssaBindings.end()) {
|
||||
return it->second;
|
||||
|
@ -163,12 +170,12 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
|
|||
Value *res = nullptr;
|
||||
if (auto un = e.dyn_cast<UnaryExpr>()) {
|
||||
if (un.getKind() == ExprKind::Dealloc) {
|
||||
builder->create<DeallocOp>(location, emit(un.getExpr()));
|
||||
builder->create<DeallocOp>(location, emitExpr(un.getExpr()));
|
||||
return nullptr;
|
||||
}
|
||||
} else if (auto bin = e.dyn_cast<BinaryExpr>()) {
|
||||
auto *a = emit(bin.getLHS());
|
||||
auto *b = emit(bin.getRHS());
|
||||
auto *a = emitExpr(bin.getLHS());
|
||||
auto *b = emitExpr(bin.getRHS());
|
||||
if (!a || !b) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -221,9 +228,9 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
|
|||
|
||||
if (auto ter = e.dyn_cast<TernaryExpr>()) {
|
||||
if (ter.getKind() == ExprKind::Select) {
|
||||
auto *cond = emit(ter.getCond());
|
||||
auto *lhs = emit(ter.getLHS());
|
||||
auto *rhs = emit(ter.getRHS());
|
||||
auto *cond = emitExpr(ter.getCond());
|
||||
auto *lhs = emitExpr(ter.getLHS());
|
||||
auto *rhs = emitExpr(ter.getRHS());
|
||||
if (!cond || !rhs || !lhs) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -233,7 +240,7 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
|
|||
|
||||
if (auto nar = e.dyn_cast<VariadicExpr>()) {
|
||||
if (nar.getKind() == ExprKind::Alloc) {
|
||||
auto exprs = emit(nar.getExprs());
|
||||
auto exprs = emitExprs(nar.getExprs());
|
||||
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -243,26 +250,26 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
|
|||
builder->create<AllocOp>(location, types[0].cast<MemRefType>(), exprs)
|
||||
->getResult();
|
||||
} else if (nar.getKind() == ExprKind::Load) {
|
||||
auto exprs = emit(nar.getExprs());
|
||||
auto exprs = emitExprs(nar.getExprs());
|
||||
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
|
||||
return nullptr;
|
||||
}
|
||||
assert(exprs.size() > 1 && "Expected > 1 expr");
|
||||
assert(nar.getTypes().empty() && "Expected no type");
|
||||
assert(!exprs.empty() && "Load requires >= 1 exprs");
|
||||
assert(nar.getTypes().empty() && "Load expects no type");
|
||||
SmallVector<Value *, 8> vals(exprs.begin() + 1, exprs.end());
|
||||
res = builder->create<LoadOp>(location, exprs[0], vals)->getResult();
|
||||
} else if (nar.getKind() == ExprKind::Store) {
|
||||
auto exprs = emit(nar.getExprs());
|
||||
auto exprs = emitExprs(nar.getExprs());
|
||||
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
|
||||
return nullptr;
|
||||
}
|
||||
assert(exprs.size() > 2 && "Expected > 2 expr");
|
||||
assert(nar.getTypes().empty() && "Expected no type");
|
||||
assert(exprs.size() >= 2 && "Store requires >= 2 exprs");
|
||||
assert(nar.getTypes().empty() && "Store expects no type");
|
||||
SmallVector<Value *, 8> vals(exprs.begin() + 2, exprs.end());
|
||||
builder->create<StoreOp>(location, exprs[0], exprs[1], vals);
|
||||
return nullptr;
|
||||
} else if (nar.getKind() == ExprKind::VectorTypeCast) {
|
||||
auto exprs = emit(nar.getExprs());
|
||||
auto exprs = emitExprs(nar.getExprs());
|
||||
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -274,7 +281,7 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
|
|||
types[0].cast<MemRefType>())
|
||||
->getResult();
|
||||
} else if (nar.getKind() == ExprKind::Return) {
|
||||
auto exprs = emit(nar.getExprs());
|
||||
auto exprs = emitExprs(nar.getExprs());
|
||||
builder->create<ReturnOp>(location, exprs);
|
||||
return nullptr; // no Value* produced and this is fine.
|
||||
}
|
||||
|
@ -282,12 +289,11 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
|
|||
|
||||
if (auto expr = e.dyn_cast<StmtBlockLikeExpr>()) {
|
||||
if (expr.getKind() == ExprKind::For) {
|
||||
auto exprs = emit(expr.getExprs());
|
||||
auto exprs = emitExprs(expr.getExprs());
|
||||
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
|
||||
return nullptr;
|
||||
}
|
||||
assert(exprs.size() == 3 && "Expected 3 exprs");
|
||||
assert(expr.getTypes().empty() && "Expected no type");
|
||||
auto lb =
|
||||
exprs[0]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
|
||||
auto ub =
|
||||
|
@ -318,23 +324,24 @@ Value *mlir::edsc::MLIREmitter::emit(Expr e) {
|
|||
return res;
|
||||
}
|
||||
|
||||
SmallVector<Value *, 8> mlir::edsc::MLIREmitter::emit(ArrayRef<Expr> exprs) {
|
||||
return mlir::functional::map(
|
||||
[this](Expr e) {
|
||||
auto *res = this->emit(e);
|
||||
SmallVector<Value *, 8>
|
||||
mlir::edsc::MLIREmitter::emitExprs(ArrayRef<Expr> exprs) {
|
||||
SmallVector<Value *, 8> res;
|
||||
res.reserve(exprs.size());
|
||||
for (auto e : exprs) {
|
||||
res.push_back(this->emitExpr(e));
|
||||
LLVM_DEBUG(
|
||||
printDefininingStatement(llvm::dbgs() << "\nEmitted: ", *res));
|
||||
printDefininingStatement(llvm::dbgs() << "\nEmitted: ", *res.back()));
|
||||
}
|
||||
return res;
|
||||
},
|
||||
exprs);
|
||||
}
|
||||
|
||||
void mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) {
|
||||
auto *block = builder->getBlock();
|
||||
auto ip = builder->getInsertionPoint();
|
||||
// Blocks are just a containing abstraction, they do not emit their RHS.
|
||||
if (stmt.getRHS().getKind() != ExprKind::Block) {
|
||||
auto *val = emit(stmt.getRHS());
|
||||
if (stmt.getRHS().getKind() != ExprKind::StmtList) {
|
||||
auto *val = emitExpr(stmt.getRHS());
|
||||
if (!val) {
|
||||
assert((stmt.getRHS().getKind() == ExprKind::Dealloc ||
|
||||
stmt.getRHS().getKind() == ExprKind::Store ||
|
||||
|
@ -342,7 +349,8 @@ void mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) {
|
|||
"dealloc, store or return expected as the only 0-result ops");
|
||||
return;
|
||||
}
|
||||
bind(stmt.getLHS(), val);
|
||||
// Force create a bindable from stmt.lhs and bind it.
|
||||
bind(Bindable(stmt.getLHS()), val);
|
||||
if (stmt.getRHS().getKind() == ExprKind::For) {
|
||||
// Step into the loop.
|
||||
builder->setInsertionPointToStart(
|
||||
|
@ -408,10 +416,23 @@ static SmallVector<Value *, 8> getMemRefSizes(FuncBuilder *b, Location loc,
|
|||
}
|
||||
|
||||
SmallVector<edsc::Expr, 8>
|
||||
mlir::edsc::MLIREmitter::makeBoundSizes(Value *memRef) {
|
||||
mlir::edsc::MLIREmitter::makeBoundFunctionArguments(mlir::Function *function) {
|
||||
SmallVector<edsc::Expr, 8> res;
|
||||
for (unsigned pos = 0, npos = function->getNumArguments(); pos < npos;
|
||||
++pos) {
|
||||
auto *arg = function->getArgument(pos);
|
||||
Expr b;
|
||||
bind(Bindable(b), arg);
|
||||
res.push_back(Expr(b));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
SmallVector<edsc::Expr, 8>
|
||||
mlir::edsc::MLIREmitter::makeBoundMemRefShape(Value *memRef) {
|
||||
assert(memRef->getType().isa<MemRefType>() && "Expected a MemRef value");
|
||||
MemRefType memRefType = memRef->getType().cast<MemRefType>();
|
||||
auto memRefSizes = edsc::makeBindables(memRefType.getShape().size());
|
||||
auto memRefSizes = edsc::makeNewExprs(memRefType.getShape().size());
|
||||
auto memrefSizeValues = getMemRefSizes(getBuilder(), getLocation(), memRef);
|
||||
assert(memrefSizeValues.size() == memRefSizes.size());
|
||||
bindZipRange(llvm::zip(memRefSizes, memrefSizeValues));
|
||||
|
@ -419,37 +440,71 @@ mlir::edsc::MLIREmitter::makeBoundSizes(Value *memRef) {
|
|||
return res;
|
||||
}
|
||||
|
||||
mlir::edsc::MLIREmitter::BoundMemRefView
|
||||
mlir::edsc::MLIREmitter::makeBoundMemRefView(Value *memRef) {
|
||||
auto memRefType = memRef->getType().cast<mlir::MemRefType>();
|
||||
auto rank = memRefType.getRank();
|
||||
|
||||
SmallVector<edsc::Expr, 8> lbs;
|
||||
lbs.reserve(rank);
|
||||
Expr zero;
|
||||
bindConstant<mlir::ConstantIndexOp>(Bindable(zero), 0);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
lbs.push_back(zero);
|
||||
}
|
||||
|
||||
auto ubs = makeBoundMemRefShape(memRef);
|
||||
|
||||
SmallVector<edsc::Expr, 8> steps;
|
||||
lbs.reserve(rank);
|
||||
Expr one;
|
||||
bindConstant<mlir::ConstantIndexOp>(Bindable(one), 1);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
steps.push_back(one);
|
||||
}
|
||||
|
||||
return BoundMemRefView{lbs, ubs, steps};
|
||||
}
|
||||
|
||||
mlir::edsc::MLIREmitter::BoundMemRefView
|
||||
mlir::edsc::MLIREmitter::makeBoundMemRefView(Expr boundMemRef) {
|
||||
auto *v = getValue(mlir::edsc::Expr(boundMemRef));
|
||||
assert(v && "Expected a bound Expr");
|
||||
return makeBoundMemRefView(v);
|
||||
}
|
||||
|
||||
edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Bindable b;
|
||||
e->bindConstant<mlir::ConstantFloatOp>(b, mlir::APFloat(value),
|
||||
Expr b;
|
||||
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value),
|
||||
e->getBuilder()->getBF16Type());
|
||||
return b;
|
||||
}
|
||||
|
||||
edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Bindable b;
|
||||
Expr b;
|
||||
bool unused;
|
||||
mlir::APFloat val(value);
|
||||
val.convert(e->getBuilder()->getF16Type().getFloatSemantics(),
|
||||
mlir::APFloat::rmNearestTiesToEven, &unused);
|
||||
e->bindConstant<mlir::ConstantFloatOp>(b, val, e->getBuilder()->getF16Type());
|
||||
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), val,
|
||||
e->getBuilder()->getF16Type());
|
||||
return b;
|
||||
}
|
||||
|
||||
edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Bindable b;
|
||||
e->bindConstant<mlir::ConstantFloatOp>(b, mlir::APFloat(value),
|
||||
Expr b;
|
||||
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value),
|
||||
e->getBuilder()->getF32Type());
|
||||
return b;
|
||||
}
|
||||
|
||||
edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Bindable b;
|
||||
e->bindConstant<mlir::ConstantFloatOp>(b, mlir::APFloat(value),
|
||||
Expr b;
|
||||
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value),
|
||||
e->getBuilder()->getF64Type());
|
||||
return b;
|
||||
}
|
||||
|
@ -457,7 +512,7 @@ edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value) {
|
|||
edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value,
|
||||
unsigned bitwidth) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Bindable b;
|
||||
Expr b;
|
||||
e->bindConstant<mlir::ConstantIntOp>(
|
||||
b, value, e->getBuilder()->getIntegerType(bitwidth));
|
||||
return b;
|
||||
|
@ -465,8 +520,8 @@ edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value,
|
|||
|
||||
edsc_expr_t bindConstantIndex(edsc_mlir_emitter_t emitter, int64_t value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Bindable b;
|
||||
e->bindConstant<mlir::ConstantIndexOp>(b, value);
|
||||
Expr b;
|
||||
e->bindConstant<mlir::ConstantIndexOp>(Bindable(b), value);
|
||||
return b;
|
||||
}
|
||||
|
||||
|
@ -493,8 +548,8 @@ edsc_expr_t bindFunctionArgument(edsc_mlir_emitter_t emitter,
|
|||
auto *f = reinterpret_cast<mlir::Function *>(function);
|
||||
assert(pos < f->getNumArguments());
|
||||
auto *arg = *(f->getArguments().begin() + pos);
|
||||
Bindable b;
|
||||
e->bind(b, arg);
|
||||
Expr b;
|
||||
e->bind(Bindable(b), arg);
|
||||
return Expr(b);
|
||||
}
|
||||
|
||||
|
@ -505,8 +560,8 @@ void bindFunctionArguments(edsc_mlir_emitter_t emitter, mlir_func_t function,
|
|||
assert(result->n == f->getNumArguments());
|
||||
for (unsigned pos = 0; pos < result->n; ++pos) {
|
||||
auto *arg = *(f->getArguments().begin() + pos);
|
||||
Bindable b;
|
||||
e->bind(b, arg);
|
||||
Expr b;
|
||||
e->bind(Bindable(b), arg);
|
||||
result->exprs[pos] = Expr(b);
|
||||
}
|
||||
}
|
||||
|
@ -515,6 +570,7 @@ unsigned getBoundMemRefRank(edsc_mlir_emitter_t emitter,
|
|||
edsc_expr_t boundMemRef) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
auto *v = e->getValue(mlir::edsc::Expr(boundMemRef));
|
||||
assert(v && "Expected a bound Expr");
|
||||
auto memRefType = v->getType().cast<mlir::MemRefType>();
|
||||
return memRefType.getRank();
|
||||
}
|
||||
|
@ -523,10 +579,11 @@ void bindMemRefShape(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
|
|||
edsc_expr_list_t *result) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
auto *v = e->getValue(mlir::edsc::Expr(boundMemRef));
|
||||
assert(v && "Expected a bound Expr");
|
||||
auto memRefType = v->getType().cast<mlir::MemRefType>();
|
||||
auto rank = memRefType.getRank();
|
||||
assert(result->n == rank && "Unexpected memref shape binding results count");
|
||||
auto bindables = e->makeBoundSizes(v);
|
||||
auto bindables = e->makeBoundMemRefShape(v);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
result->exprs[i] = bindables[i];
|
||||
}
|
||||
|
@ -542,14 +599,14 @@ void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
|
|||
assert(resultLbs->n == rank && "Unexpected memref binding results count");
|
||||
assert(resultUbs->n == rank && "Unexpected memref binding results count");
|
||||
assert(resultSteps->n == rank && "Unexpected memref binding results count");
|
||||
auto bindables = e->makeBoundSizes(v);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
Bindable zero;
|
||||
auto bindables = e->makeBoundMemRefShape(v);
|
||||
Expr zero;
|
||||
e->bindConstant<mlir::ConstantIndexOp>(zero, 0);
|
||||
Expr one;
|
||||
e->bindConstant<mlir::ConstantIndexOp>(one, 1);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
resultLbs->exprs[i] = zero;
|
||||
resultUbs->exprs[i] = bindables[i];
|
||||
Bindable one;
|
||||
e->bindConstant<mlir::ConstantIndexOp>(one, 1);
|
||||
resultSteps->exprs[i] = one;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,12 +39,9 @@ namespace edsc {
|
|||
namespace detail {
|
||||
|
||||
struct ExprStorage {
|
||||
ExprStorage(ExprKind kind) : kind(kind) {}
|
||||
ExprStorage(ExprKind kind, unsigned id = Expr::newId())
|
||||
: kind(kind), id(id) {}
|
||||
ExprKind kind;
|
||||
};
|
||||
|
||||
struct BindableStorage : public ExprStorage {
|
||||
BindableStorage(unsigned id) : ExprStorage(ExprKind::Unbound), id(id) {}
|
||||
unsigned id;
|
||||
};
|
||||
|
||||
|
@ -94,8 +91,23 @@ mlir::edsc::ScopedEDSCContext::~ScopedEDSCContext() {
|
|||
Expr::globalAllocator() = nullptr;
|
||||
}
|
||||
|
||||
mlir::edsc::Expr::Expr() {
|
||||
// Initialize with placement new.
|
||||
storage = Expr::globalAllocator()->Allocate<detail::ExprStorage>();
|
||||
new (storage) detail::ExprStorage(ExprKind::Unbound);
|
||||
}
|
||||
|
||||
ExprKind mlir::edsc::Expr::getKind() const { return storage->kind; }
|
||||
|
||||
unsigned mlir::edsc::Expr::getId() const {
|
||||
return static_cast<ImplType *>(storage)->id;
|
||||
}
|
||||
|
||||
unsigned &mlir::edsc::Expr::newId() {
|
||||
static thread_local unsigned id = 0;
|
||||
return ++id;
|
||||
}
|
||||
|
||||
Expr mlir::edsc::Expr::operator+(Expr other) const {
|
||||
return BinaryExpr(ExprKind::Add, *this, other);
|
||||
}
|
||||
|
@ -132,25 +144,7 @@ Expr mlir::edsc::Expr::operator||(Expr other) const {
|
|||
}
|
||||
|
||||
// Free functions.
|
||||
llvm::SmallVector<Bindable, 8> mlir::edsc::makeBindables(unsigned n) {
|
||||
llvm::SmallVector<Bindable, 8> res;
|
||||
res.reserve(n);
|
||||
for (auto i = 0; i < n; ++i) {
|
||||
res.push_back(Bindable());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
static llvm::SmallVector<Bindable, 8> makeBindables(edsc_expr_list_t exprList) {
|
||||
llvm::SmallVector<Bindable, 8> exprs;
|
||||
exprs.reserve(exprList.n);
|
||||
for (unsigned i = 0; i < exprList.n; ++i) {
|
||||
exprs.push_back(Expr(exprList.exprs[i]).cast<Bindable>());
|
||||
}
|
||||
return exprs;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Expr, 8> mlir::edsc::makeExprs(unsigned n) {
|
||||
llvm::SmallVector<Expr, 8> mlir::edsc::makeNewExprs(unsigned n) {
|
||||
llvm::SmallVector<Expr, 8> res;
|
||||
res.reserve(n);
|
||||
for (auto i = 0; i < n; ++i) {
|
||||
|
@ -159,15 +153,6 @@ llvm::SmallVector<Expr, 8> mlir::edsc::makeExprs(unsigned n) {
|
|||
return res;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Expr, 8> mlir::edsc::makeExprs(ArrayRef<Bindable> bindables) {
|
||||
llvm::SmallVector<Expr, 8> res;
|
||||
res.reserve(bindables.size());
|
||||
for (auto b : bindables) {
|
||||
res.push_back(b);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
static llvm::SmallVector<Expr, 8> makeExprs(edsc_expr_list_t exprList) {
|
||||
llvm::SmallVector<Expr, 8> exprs;
|
||||
exprs.reserve(exprList.n);
|
||||
|
@ -177,25 +162,26 @@ static llvm::SmallVector<Expr, 8> makeExprs(edsc_expr_list_t exprList) {
|
|||
return exprs;
|
||||
}
|
||||
|
||||
static llvm::SmallVector<Stmt, 8> makeStmts(edsc_stmt_list_t enclosedStmts) {
|
||||
llvm::SmallVector<Stmt, 8> stmts;
|
||||
stmts.reserve(enclosedStmts.n);
|
||||
static void fillStmts(edsc_stmt_list_t enclosedStmts,
|
||||
llvm::SmallVector<Stmt, 8> *stmts) {
|
||||
stmts->reserve(enclosedStmts.n);
|
||||
for (unsigned i = 0; i < enclosedStmts.n; ++i) {
|
||||
stmts.push_back(Stmt(enclosedStmts.stmts[i]));
|
||||
stmts->push_back(Stmt(enclosedStmts.stmts[i]));
|
||||
}
|
||||
return stmts;
|
||||
}
|
||||
|
||||
Expr mlir::edsc::alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) {
|
||||
return VariadicExpr(ExprKind::Alloc, sizes, memrefType);
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::Block(ArrayRef<Stmt> stmts) {
|
||||
return Stmt(StmtBlockLikeExpr(ExprKind::Block, {}), stmts);
|
||||
Stmt mlir::edsc::StmtList(ArrayRef<Stmt> stmts) {
|
||||
return Stmt(StmtBlockLikeExpr(ExprKind::StmtList, {}), stmts);
|
||||
}
|
||||
|
||||
edsc_stmt_t Block(edsc_stmt_list_t enclosedStmts) {
|
||||
return Stmt(mlir::edsc::Block(makeStmts(enclosedStmts)));
|
||||
edsc_stmt_t StmtList(edsc_stmt_list_t enclosedStmts) {
|
||||
llvm::SmallVector<Stmt, 8> stmts;
|
||||
fillStmts(enclosedStmts, &stmts);
|
||||
return Stmt(mlir::edsc::StmtList(stmts));
|
||||
}
|
||||
|
||||
Expr mlir::edsc::dealloc(Expr memref) {
|
||||
|
@ -203,8 +189,8 @@ Expr mlir::edsc::dealloc(Expr memref) {
|
|||
}
|
||||
|
||||
Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) {
|
||||
Bindable idx;
|
||||
return For(idx, lb, ub, step, stmts);
|
||||
Expr idx;
|
||||
return For(Bindable(idx), lb, ub, step, stmts);
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step,
|
||||
|
@ -212,105 +198,68 @@ Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step,
|
|||
return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, {lb, ub, step}), stmts);
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::For(MutableArrayRef<Bindable> indices, ArrayRef<Expr> lbs,
|
||||
Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs,
|
||||
ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
|
||||
ArrayRef<Stmt> enclosedStmts) {
|
||||
assert(!indices.empty());
|
||||
assert(indices.size() == lbs.size());
|
||||
assert(indices.size() == ubs.size());
|
||||
assert(indices.size() == steps.size());
|
||||
Expr iv = indices.back();
|
||||
Stmt curStmt =
|
||||
For(indices.back(), lbs.back(), ubs.back(), steps.back(), enclosedStmts);
|
||||
For(Bindable(iv), lbs.back(), ubs.back(), steps.back(), enclosedStmts);
|
||||
for (int64_t i = indices.size() - 2; i >= 0; --i) {
|
||||
curStmt = For(indices[i], lbs[i], ubs[i], steps[i], {curStmt});
|
||||
Expr iiv = indices[i];
|
||||
curStmt.set(For(Bindable(iiv), lbs[i], ubs[i], steps[i],
|
||||
llvm::ArrayRef<Stmt>{&curStmt, 1}));
|
||||
}
|
||||
return curStmt;
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::For(llvm::MutableArrayRef<Bindable> indices,
|
||||
llvm::ArrayRef<Bindable> lbs, llvm::ArrayRef<Bindable> ubs,
|
||||
llvm::ArrayRef<Bindable> steps,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts) {
|
||||
return For(indices, SmallVector<Expr, 8>{lbs.begin(), lbs.end()},
|
||||
SmallVector<Expr, 8>{ubs.begin(), ubs.end()},
|
||||
SmallVector<Expr, 8>{steps.begin(), steps.end()}, enclosedStmts);
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::For(llvm::MutableArrayRef<Bindable> indices,
|
||||
llvm::ArrayRef<Bindable> lbs, llvm::ArrayRef<Expr> ubs,
|
||||
llvm::ArrayRef<Bindable> steps,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts) {
|
||||
return For(indices, SmallVector<Expr, 8>{lbs.begin(), lbs.end()}, ubs,
|
||||
SmallVector<Expr, 8>{steps.begin(), steps.end()}, enclosedStmts);
|
||||
}
|
||||
|
||||
edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
|
||||
edsc_expr_t step, edsc_stmt_list_t enclosedStmts) {
|
||||
return Stmt(For(Expr(iv).cast<Bindable>(), Expr(lb), Expr(ub), Expr(step),
|
||||
makeStmts(enclosedStmts)));
|
||||
llvm::SmallVector<Stmt, 8> stmts;
|
||||
fillStmts(enclosedStmts, &stmts);
|
||||
return Stmt(
|
||||
For(Expr(iv).cast<Bindable>(), Expr(lb), Expr(ub), Expr(step), stmts));
|
||||
}
|
||||
|
||||
edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs,
|
||||
edsc_expr_list_t ubs, edsc_expr_list_t steps,
|
||||
edsc_stmt_list_t enclosedStmts) {
|
||||
auto bindables = makeBindables(ivs);
|
||||
return Stmt(For(bindables, makeExprs(lbs), makeExprs(ubs), makeExprs(steps),
|
||||
makeStmts(enclosedStmts)));
|
||||
llvm::SmallVector<Stmt, 8> stmts;
|
||||
fillStmts(enclosedStmts, &stmts);
|
||||
return Stmt(For(makeExprs(ivs), makeExprs(lbs), makeExprs(ubs),
|
||||
makeExprs(steps), stmts));
|
||||
}
|
||||
|
||||
template <typename BindableOrExpr>
|
||||
static Expr loadBuilder(Expr m, ArrayRef<BindableOrExpr> indices) {
|
||||
Expr mlir::edsc::load(Expr m, ArrayRef<Expr> indices) {
|
||||
SmallVector<Expr, 8> exprs;
|
||||
exprs.push_back(m);
|
||||
exprs.append(indices.begin(), indices.end());
|
||||
return VariadicExpr(ExprKind::Load, exprs);
|
||||
}
|
||||
Expr mlir::edsc::load(Expr m, Expr index) {
|
||||
return loadBuilder<Expr>(m, {index});
|
||||
}
|
||||
Expr mlir::edsc::load(Expr m, Bindable index) {
|
||||
return loadBuilder<Bindable>(m, {index});
|
||||
}
|
||||
Expr mlir::edsc::load(Expr m, const llvm::SmallVectorImpl<Bindable> &indices) {
|
||||
return loadBuilder(m, ArrayRef<Bindable>{indices.begin(), indices.end()});
|
||||
}
|
||||
Expr mlir::edsc::load(Expr m, ArrayRef<Expr> indices) {
|
||||
return loadBuilder(m, indices);
|
||||
}
|
||||
|
||||
edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices) {
|
||||
Indexed i(Expr(indexed.base).cast<Bindable>());
|
||||
Expr res = i[makeExprs(indices)];
|
||||
auto exprs = makeExprs(indices);
|
||||
Expr res = i(exprs);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename BindableOrExpr>
|
||||
static Expr storeBuilder(Expr val, Expr m, ArrayRef<BindableOrExpr> indices) {
|
||||
Expr mlir::edsc::store(Expr val, Expr m, ArrayRef<Expr> indices) {
|
||||
SmallVector<Expr, 8> exprs;
|
||||
exprs.push_back(val);
|
||||
exprs.push_back(m);
|
||||
exprs.append(indices.begin(), indices.end());
|
||||
return VariadicExpr(ExprKind::Store, exprs);
|
||||
}
|
||||
Expr mlir::edsc::store(Expr val, Expr m, Expr index) {
|
||||
return storeBuilder<Expr>(val, m, {index});
|
||||
}
|
||||
Expr mlir::edsc::store(Expr val, Expr m, Bindable index) {
|
||||
return storeBuilder<Bindable>(val, m, {index});
|
||||
}
|
||||
Expr mlir::edsc::store(Expr val, Expr m,
|
||||
const llvm::SmallVectorImpl<Bindable> &indices) {
|
||||
return storeBuilder(val, m,
|
||||
ArrayRef<Bindable>{indices.begin(), indices.end()});
|
||||
}
|
||||
Expr mlir::edsc::store(Expr val, Expr m, ArrayRef<Expr> indices) {
|
||||
return storeBuilder(val, m, indices);
|
||||
}
|
||||
|
||||
edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed,
|
||||
edsc_expr_list_t indices) {
|
||||
Indexed i(Expr(indexed.base).cast<Bindable>());
|
||||
Indexed loc = i[makeExprs(indices)];
|
||||
auto exprs = makeExprs(indices);
|
||||
Indexed loc = i(exprs);
|
||||
return Stmt(loc = Expr(value));
|
||||
}
|
||||
|
||||
|
@ -384,15 +333,20 @@ void mlir::edsc::Expr::print(raw_ostream &os) const {
|
|||
}
|
||||
}
|
||||
} else if (auto nar = this->dyn_cast<VariadicExpr>()) {
|
||||
auto exprs = nar.getExprs();
|
||||
switch (nar.getKind()) {
|
||||
case ExprKind::Load:
|
||||
os << "load( ... )";
|
||||
os << "load(" << exprs[0] << "[";
|
||||
interleaveComma(ArrayRef<Expr>(exprs.begin() + 1, exprs.size() - 1), os);
|
||||
os << "])";
|
||||
return;
|
||||
case ExprKind::Store:
|
||||
os << "store( ... )";
|
||||
os << "store(" << exprs[0] << ", " << exprs[1] << "[";
|
||||
interleaveComma(ArrayRef<Expr>(exprs.begin() + 2, exprs.size() - 2), os);
|
||||
os << "])";
|
||||
return;
|
||||
case ExprKind::Return:
|
||||
interleaveComma(nar.getExprs(), os);
|
||||
interleaveComma(exprs, os);
|
||||
return;
|
||||
default: {
|
||||
os << "unknown_variadic";
|
||||
|
@ -430,22 +384,7 @@ llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
|
|||
return os;
|
||||
}
|
||||
|
||||
mlir::edsc::Bindable::Bindable()
|
||||
: Expr(Expr::globalAllocator()->Allocate<detail::BindableStorage>()) {
|
||||
// Initialize with placement new.
|
||||
new (storage) detail::BindableStorage{Bindable::newId()};
|
||||
}
|
||||
|
||||
edsc_expr_t makeBindable() { return Bindable(); }
|
||||
|
||||
unsigned mlir::edsc::Bindable::getId() const {
|
||||
return static_cast<ImplType *>(storage)->id;
|
||||
}
|
||||
|
||||
unsigned &mlir::edsc::Bindable::newId() {
|
||||
static thread_local unsigned id = 0;
|
||||
return ++id;
|
||||
}
|
||||
edsc_expr_t makeBindable() { return Bindable(Expr()); }
|
||||
|
||||
mlir::edsc::UnaryExpr::UnaryExpr(ExprKind kind, Expr expr)
|
||||
: Expr(Expr::globalAllocator()->Allocate<detail::UnaryExprStorage>()) {
|
||||
|
@ -519,9 +458,6 @@ mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind,
|
|||
ArrayRef<Expr> mlir::edsc::StmtBlockLikeExpr::getExprs() const {
|
||||
return static_cast<ImplType *>(storage)->exprs;
|
||||
}
|
||||
ArrayRef<Type> mlir::edsc::StmtBlockLikeExpr::getTypes() const {
|
||||
return static_cast<ImplType *>(storage)->types;
|
||||
}
|
||||
|
||||
mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts) {
|
||||
|
@ -536,7 +472,7 @@ mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
|
|||
}
|
||||
|
||||
mlir::edsc::Stmt::Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> enclosedStmts)
|
||||
: Stmt(Bindable(), rhs, enclosedStmts) {}
|
||||
: Stmt(Bindable(Expr()), rhs, enclosedStmts) {}
|
||||
|
||||
edsc_stmt_t makeStmt(edsc_expr_t e) {
|
||||
assert(e && "unexpected empty expression");
|
||||
|
@ -544,12 +480,12 @@ edsc_stmt_t makeStmt(edsc_expr_t e) {
|
|||
}
|
||||
|
||||
Stmt &mlir::edsc::Stmt::operator=(const Expr &expr) {
|
||||
Stmt res(Bindable(), expr, {});
|
||||
Stmt res(Bindable(Expr()), expr, {});
|
||||
std::swap(res.storage, this->storage);
|
||||
return *this;
|
||||
}
|
||||
|
||||
Bindable mlir::edsc::Stmt::getLHS() const {
|
||||
Expr mlir::edsc::Stmt::getLHS() const {
|
||||
return static_cast<ImplType *>(storage)->lhs;
|
||||
}
|
||||
|
||||
|
@ -562,7 +498,10 @@ llvm::ArrayRef<Stmt> mlir::edsc::Stmt::getEnclosedStmts() const {
|
|||
}
|
||||
|
||||
void mlir::edsc::Stmt::print(raw_ostream &os, Twine indent) const {
|
||||
assert(storage && "Unexpected null storage,stmt must be bound to print");
|
||||
if (!storage) {
|
||||
os << "null_storage";
|
||||
return;
|
||||
}
|
||||
auto lhs = getLHS();
|
||||
auto rhs = getRHS();
|
||||
|
||||
|
@ -580,8 +519,8 @@ void mlir::edsc::Stmt::print(raw_ostream &os, Twine indent) const {
|
|||
}
|
||||
os << indent << "}";
|
||||
return;
|
||||
case ExprKind::Block:
|
||||
os << indent << "block {";
|
||||
case ExprKind::StmtList:
|
||||
os << indent << "stmt_list {";
|
||||
for (auto &s : getEnclosedStmts()) {
|
||||
os << "\n";
|
||||
s.print(os, indent + " ");
|
||||
|
@ -613,24 +552,15 @@ llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
|
|||
return os;
|
||||
}
|
||||
|
||||
Indexed mlir::edsc::Indexed::operator[](llvm::ArrayRef<Expr> indices) const {
|
||||
Indexed mlir::edsc::Indexed::operator()(llvm::ArrayRef<Expr> indices) {
|
||||
Indexed res(base);
|
||||
res.indices = llvm::SmallVector<Expr, 4>(indices.begin(), indices.end());
|
||||
return res;
|
||||
}
|
||||
|
||||
Indexed mlir::edsc::Indexed::
|
||||
operator[](llvm::ArrayRef<Bindable> indices) const {
|
||||
return (*this)[llvm::ArrayRef<Expr>{indices.begin(), indices.end()}];
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE: unconventional-assign-operator
|
||||
Stmt mlir::edsc::Indexed::operator=(Expr expr) {
|
||||
assert(!indices.empty() && "Expected attached indices to Indexed");
|
||||
assert(base);
|
||||
Stmt stmt(store(expr, base, indices));
|
||||
indices.clear();
|
||||
return stmt;
|
||||
return Stmt(store(expr, base, indices));
|
||||
}
|
||||
|
||||
edsc_indexed_t makeIndexed(edsc_expr_t expr) {
|
||||
|
|
|
@ -68,7 +68,7 @@ namespace {
|
|||
/// local storage.
|
||||
struct VectorTransferAccessInfo {
|
||||
// `ivs` are bound for `For` Stmt at `For` Stmt construction time.
|
||||
llvm::SmallVector<edsc::Bindable, 8> ivs;
|
||||
llvm::SmallVector<edsc::Expr, 8> ivs;
|
||||
llvm::SmallVector<edsc::Expr, 8> lowerBoundsExprs;
|
||||
llvm::SmallVector<edsc::Expr, 8> upperBoundsExprs;
|
||||
llvm::SmallVector<edsc::Expr, 8> stepExprs;
|
||||
|
@ -107,15 +107,15 @@ private:
|
|||
/// buffer.
|
||||
MemRefType vectorMemRefType;
|
||||
|
||||
// EDSC `emitter` and Bindables that are pre-bound at construction time.
|
||||
// vectorSizes are bound to the actual constant sizes of vectorType.
|
||||
llvm::SmallVector<edsc::Bindable, 8> vectorSizes;
|
||||
// accesses are bound to transfer->getIndices()
|
||||
llvm::SmallVector<edsc::Bindable, 8> accesses;
|
||||
// `zero` and `one` are bound to locally scoped constants.
|
||||
// `scalarMemRef` is bound to `transfer->getMemRef()`.
|
||||
edsc::Bindable zero, one, scalarMemRef;
|
||||
// EDSC `emitter` and Exprs that are pre-bound at construction time.
|
||||
edsc::MLIREmitter emitter;
|
||||
// vectorSizes are bound to the actual constant sizes of vectorType.
|
||||
llvm::SmallVector<edsc::Expr, 8> vectorSizes;
|
||||
// accesses are bound to transfer->getIndices()
|
||||
llvm::SmallVector<edsc::Expr, 8> accesses;
|
||||
// `zero` and `one` are bound emitter.zero() and emitter.one().
|
||||
// `scalarMemRef` is bound to `transfer->getMemRef()`.
|
||||
edsc::Expr zero, one, scalarMemRef;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
@ -164,19 +164,19 @@ VectorTransferAccessInfo
|
|||
VectorTransferRewriter<VectorTransferOpTy>::makeVectorTransferAccessInfo() {
|
||||
using namespace mlir::edsc;
|
||||
|
||||
// Create Bindable objects for ivs, they will be bound at `For` Stmt
|
||||
// Create new Exprs for ivs, they will be bound at `For` Stmt
|
||||
// construction.
|
||||
auto ivs = makeBindables(vectorShape.size());
|
||||
auto ivs = makeNewExprs(vectorShape.size());
|
||||
|
||||
// Create and bind Bindables to refer to the Value for memref sizes.
|
||||
auto memRefSizes = emitter.makeBoundSizes(transfer->getMemRef());
|
||||
// Create and bind Exprs to refer to the Value for memref sizes.
|
||||
auto memRefSizes = emitter.makeBoundMemRefShape(transfer->getMemRef());
|
||||
|
||||
// Create the edsc::Expr for the clipped and transposes access expressions
|
||||
// using the permutationMap. Additionally, capture the index accessing the
|
||||
// most minor dimension.
|
||||
int coalescingIndex = -1;
|
||||
auto clippedScalarAccessExprs = makeExprs(accesses);
|
||||
auto tmpAccessExprs = makeExprs(ivs);
|
||||
auto clippedScalarAccessExprs = copyExprs(accesses);
|
||||
auto tmpAccessExprs = copyExprs(ivs);
|
||||
llvm::DenseSet<unsigned> clipped;
|
||||
for (auto it : llvm::enumerate(permutationMap.getResults())) {
|
||||
if (auto affineExpr = it.value().template dyn_cast<AffineDimExpr>()) {
|
||||
|
@ -221,9 +221,9 @@ VectorTransferRewriter<VectorTransferOpTy>::makeVectorTransferAccessInfo() {
|
|||
|
||||
// Create the proper bindables for lbs, ubs and steps. Additionally, if we
|
||||
// recorded a coalescing index, permute the loop informations.
|
||||
auto lbs = makeBindables(ivs.size());
|
||||
auto ubs = makeExprs(vectorSizes);
|
||||
auto steps = makeBindables(ivs.size());
|
||||
auto lbs = makeNewExprs(ivs.size());
|
||||
auto ubs = copyExprs(vectorSizes);
|
||||
auto steps = makeNewExprs(ivs.size());
|
||||
if (coalescingIndex >= 0) {
|
||||
std::swap(ivs[coalescingIndex], ivs.back());
|
||||
std::swap(lbs[coalescingIndex], lbs.back());
|
||||
|
@ -237,9 +237,9 @@ VectorTransferRewriter<VectorTransferOpTy>::makeVectorTransferAccessInfo() {
|
|||
llvm::zip(steps, SmallVector<int64_t, 8>(ivs.size(), 1)));
|
||||
|
||||
return VectorTransferAccessInfo{ivs,
|
||||
makeExprs(lbs),
|
||||
copyExprs(lbs),
|
||||
ubs,
|
||||
makeExprs(steps),
|
||||
copyExprs(steps),
|
||||
clippedScalarAccessExprs,
|
||||
tmpAccessExprs};
|
||||
}
|
||||
|
@ -255,14 +255,13 @@ VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
|
|||
tmpMemRefType(
|
||||
MemRefType::get(vectorShape, vectorType.getElementType(), {}, 0)),
|
||||
vectorMemRefType(MemRefType::get({1}, vectorType, {}, 0)),
|
||||
vectorSizes(edsc::makeBindables(vectorShape.size())),
|
||||
emitter(edsc::MLIREmitter(rewriter->getBuilder(), transfer->getLoc())) {
|
||||
emitter(edsc::MLIREmitter(rewriter->getBuilder(), transfer->getLoc())),
|
||||
vectorSizes(edsc::makeNewExprs(vectorShape.size())), zero(emitter.zero()),
|
||||
one(emitter.one()) {
|
||||
// Bind the Bindable.
|
||||
SmallVector<Value *, 8> transferIndices(transfer->getIndices());
|
||||
accesses = edsc::makeBindables(transferIndices.size());
|
||||
emitter.bind(scalarMemRef, transfer->getMemRef())
|
||||
.template bindConstant<ConstantIndexOp>(zero, 0)
|
||||
.template bindConstant<ConstantIndexOp>(one, 1)
|
||||
accesses = edsc::makeNewExprs(transferIndices.size());
|
||||
emitter.bind(edsc::Bindable(scalarMemRef), transfer->getMemRef())
|
||||
.template bindZipRangeConstants<ConstantIndexOp>(
|
||||
llvm::zip(vectorSizes, vectorShape))
|
||||
.template bindZipRange(llvm::zip(accesses, transfer->getIndices()));
|
||||
|
@ -321,23 +320,24 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
|
|||
auto &lbs = accessInfo.lowerBoundsExprs;
|
||||
auto &ubs = accessInfo.upperBoundsExprs;
|
||||
auto &steps = accessInfo.stepExprs;
|
||||
Stmt scalarValue, vectorValue, tmpAlloc, tmpDealloc, vectorView;
|
||||
Stmt block = edsc::Block({
|
||||
Expr scalarValue, vectorValue, tmpAlloc, tmpDealloc, vectorView;
|
||||
Stmt block = edsc::StmtList({
|
||||
tmpAlloc = alloc(tmpMemRefType),
|
||||
vectorView = vector_type_cast(tmpAlloc, vectorMemRefType),
|
||||
vectorView = vector_type_cast(Expr(tmpAlloc), vectorMemRefType),
|
||||
For(ivs, lbs, ubs, steps, {
|
||||
scalarValue = load(scalarMemRef, accessInfo.clippedScalarAccessExprs),
|
||||
store(scalarValue, tmpAlloc, accessInfo.tmpAccessExprs),
|
||||
}),
|
||||
vectorValue = load(vectorView, {zero}),
|
||||
tmpDealloc = dealloc(tmpAlloc.getLHS())});
|
||||
tmpDealloc = dealloc(tmpAlloc)
|
||||
});
|
||||
// clang-format on
|
||||
|
||||
// Emit the MLIR.
|
||||
emitter.emitStmt(block);
|
||||
|
||||
// Finalize rewriting.
|
||||
transfer->replaceAllUsesWith(emitter.getValue(vectorValue.getLHS()));
|
||||
transfer->replaceAllUsesWith(emitter.getValue(vectorValue));
|
||||
transfer->erase();
|
||||
}
|
||||
|
||||
|
@ -367,24 +367,24 @@ template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
|
|||
auto accessInfo = makeVectorTransferAccessInfo();
|
||||
|
||||
// Bind vector value for the vector_transfer_write.
|
||||
Bindable vectorValue;
|
||||
emitter.bind(vectorValue, transfer->getVector());
|
||||
Expr vectorValue;
|
||||
emitter.bind(Bindable(vectorValue), transfer->getVector());
|
||||
|
||||
// clang-format off
|
||||
auto &ivs = accessInfo.ivs;
|
||||
auto &lbs = accessInfo.lowerBoundsExprs;
|
||||
auto &ubs = accessInfo.upperBoundsExprs;
|
||||
auto &steps = accessInfo.stepExprs;
|
||||
Stmt scalarValue, tmpAlloc, tmpDealloc, vectorView;
|
||||
Stmt block = edsc::Block({
|
||||
Expr scalarValue, tmpAlloc, tmpDealloc, vectorView;
|
||||
Stmt block(edsc::StmtList({
|
||||
tmpAlloc = alloc(tmpMemRefType),
|
||||
vectorView = vector_type_cast(tmpAlloc, vectorMemRefType),
|
||||
store(vectorValue, vectorView, {zero}),
|
||||
store(vectorValue, vectorView, MutableArrayRef<Expr>{zero}),
|
||||
For(ivs, lbs, ubs, steps, {
|
||||
scalarValue = load(tmpAlloc, accessInfo.tmpAccessExprs),
|
||||
store(scalarValue, scalarMemRef, accessInfo.clippedScalarAccessExprs),
|
||||
}),
|
||||
tmpDealloc = dealloc(tmpAlloc.getLHS())});
|
||||
tmpDealloc = dealloc(tmpAlloc)}));
|
||||
// clang-format on
|
||||
|
||||
// Emit the MLIR.
|
||||
|
|
|
@ -1,29 +1,28 @@
|
|||
// RUN: mlir-opt -lower-edsc-test %s | FileCheck %s
|
||||
|
||||
func @t1(%lhs: memref<3x4x5x6xf32>, %rhs: memref<3x4x5x6xf32>, %result: memref<3x4x5x6xf32>) -> () { return }
|
||||
func @t1(%lhs: memref<3x4x5x6xvector<4xi8>>, %rhs: memref<3x4x5x6xvector<4xi8>>, %result: memref<3x4x5x6xvector<4xi8>>) -> () { return }
|
||||
func @t2(%lhs: memref<3x4xf32>, %rhs: memref<3x4xf32>, %result: memref<3x4xf32>) -> () { return }
|
||||
func @t3(%lhs: memref<f32>, %rhs: memref<f32>, %result: memref<f32>) -> () { return }
|
||||
|
||||
func @fn() {
|
||||
"print"() {op: "x.add", fn: @t1: (memref<3x4x5x6xf32>, memref<3x4x5x6xf32>, memref<3x4x5x6xf32>) -> ()} : () -> ()
|
||||
"print"() {op: "x.add", fn: @t1: (memref<3x4x5x6xvector<4xi8>>, memref<3x4x5x6xvector<4xi8>>, memref<3x4x5x6xvector<4xi8>>) -> ()} : () -> ()
|
||||
"print"() {op: "x.add", fn: @t2: (memref<3x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()} : () -> ()
|
||||
"print"() {op: "x.add", fn: @t3: (memref<f32>, memref<f32>, memref<f32>) -> ()} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: block {
|
||||
// CHECK: for({{.*}}=[[zero1:.*]] to {{.*}} step [[step1:.*]]) {
|
||||
// CHECK: for({{.*}}=[[zero1]] to {{.*}} step [[step1]]) {
|
||||
// CHECK: for({{.*}}=[[zero1]] to {{.*}} step [[step1]]) {
|
||||
// CHECK: for({{.*}}=[[zero1]] to {{.*}} step [[step1]]) {
|
||||
// CHECK: {{.*}} = store( ... );
|
||||
// CHECK: };
|
||||
// CHECK: };
|
||||
// CHECK: };
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: block {
|
||||
// CHECK-NEXT: for({{.*}}=[[zero1]] to {{.*}} step [[step1]]) {
|
||||
// CHECK-NEXT: for({{.*}}=[[zero1]] to {{.*}} step [[step1]]) {
|
||||
// CHECK-NEXT: for({{.*}}=[[zero1]] to {{.*}} step [[step1]]) {
|
||||
// CHECK-NEXT: {{.*}} = store((load($3[{{.*}}, {{.*}}, {{.*}}, {{.*}}]) + load($4[{{.*}}, {{.*}}, {{.*}}, {{.*}}])), $5[{{.*}}, {{.*}}, {{.*}}, {{.*}}])
|
||||
// CHECK-NEXT: };
|
||||
// CHECK-NEXT: };
|
||||
// CHECK-NEXT: };
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: for({{.*}}=[[zero2:.*]] to {{.*}} step [[step2:.*]]) {
|
||||
// CHECK: for({{.*}}=[[zero2]] to {{.*}} step [[step2]]) {
|
||||
// CHECK: {{.*}} = store( ... );
|
||||
// CHECK: };
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK-NEXT: for({{.*}}=[[zero2]] to {{.*}} step [[step2]]) {
|
||||
// CHECK-NEXT: {{.*}} = store((load($3[{{.*}}, {{.*}}]) + load($4[{{.*}}, {{.*}}])), $5[{{.*}}, {{.*}}])
|
||||
// CHECK-NEXT: };
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: {{.*}} = store((load($3[]) + load($4[])), $5[])
|
||||
|
|
|
@ -6,23 +6,25 @@ include "mlir/IR/op_base.td"
|
|||
#endif // OP_BASE
|
||||
|
||||
def X_AddOp : Op<"x.add">,
|
||||
Arguments<(ins Tensor:$lhs, Tensor:$rhs)>,
|
||||
Results<(outs Tensor)> {
|
||||
Arguments<(ins Tensor:$A, Tensor:$B)>,
|
||||
Results<(outs Tensor: $C)> {
|
||||
// TODO: extract referenceImplementation to Op.
|
||||
// TODO: shrink the reference implementation
|
||||
code referenceImplementation = [{
|
||||
auto ivs = makeBindables(lhsShape.size());
|
||||
Bindable zero, one;
|
||||
// Same bindable, all equal to `zero`.
|
||||
SmallVector<Bindable, 8> zeros(ivs.size(), zero);
|
||||
// Same bindable, all equal to `one`.
|
||||
SmallVector<Bindable, 8> ones(ivs.size(), one);
|
||||
Indexed IA(lhs), IB(rhs), IC(result);
|
||||
block = edsc::Block({
|
||||
For(ivs, zeros, lhsShape, ones, {
|
||||
IC[ivs] = IA[ivs] + IB[ivs]
|
||||
})
|
||||
});
|
||||
auto ivs = makeNewExprs(view_A.rank());
|
||||
// TODO(jpienaar@): automate the positional/named extraction. Need to be a
|
||||
// bit careful about things memref (from which a "view" can be extracted)
|
||||
// and the rest (see ReferenceImplGen.cpp).
|
||||
Indexed A = args[0], B = args[1], C = args[2];
|
||||
if (ivs.size() > 0) {
|
||||
block.set(
|
||||
For(ivs, view_A.lbs, view_A.ubs, view_A.steps, {
|
||||
C(ivs) = A(ivs) + B(ivs)
|
||||
}));
|
||||
} else {
|
||||
// 0-D case is always important to get right for composability.
|
||||
block.set(C() = A() + B());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -51,42 +51,35 @@ static void emitReferenceImplementations(const RecordKeeper &recordKeeper,
|
|||
if (!ref)
|
||||
continue;
|
||||
os << " else if (opName == \"" << op.getOperationName() << "\") {\n"
|
||||
<< " edsc::MLIREmitter emitter(&builder, f->getLoc());\n";
|
||||
<< " edsc::ScopedEDSCContext raiiContext;\n"
|
||||
<< " Stmt block;\n"
|
||||
<< " edsc::MLIREmitter emitter(&builder, f->getLoc());\n"
|
||||
<< " auto zero = emitter.zero(); (void)zero;\n"
|
||||
<< " auto one = emitter.one(); (void)one;\n"
|
||||
<< " auto args = emitter.makeBoundFunctionArguments(f);\n"
|
||||
// TODO(jpienaar): this is generally incorrect, not all args are memref
|
||||
// in the general case.
|
||||
<< " auto views = emitter.makeBoundMemRefViews(args.begin(), "
|
||||
"args.end());\n";
|
||||
|
||||
// Create memrefs for the operands. Operand $x has variable name xMemRef.
|
||||
for (auto arg : op.getOperands()) {
|
||||
if (arg.name.empty())
|
||||
PrintFatalError(def->getLoc(), "all operands must be named");
|
||||
os << formatv(" mlir::BlockArgument* {0}MemRef;\n", arg.name);
|
||||
}
|
||||
os << " mlir::BlockArgument* resultMemRef;\n";
|
||||
os << " {\n auto opIt = f->getArguments().begin();\n";
|
||||
for (auto arg : op.getOperands()) {
|
||||
os.indent(4) << arg.name << "MemRef = *opIt++;\n";
|
||||
}
|
||||
os.indent(4) << "resultMemRef = *opIt++;\n";
|
||||
os << " }\n";
|
||||
|
||||
for (auto arg : op.getOperands()) {
|
||||
os << formatv(" Bindable {0}; (void){0};\n", arg.name);
|
||||
}
|
||||
os << " Bindable result;\n";
|
||||
|
||||
for (auto arg : op.getOperands()) {
|
||||
os.indent(2) << formatv(
|
||||
"auto {0}Shape = emitter.makeBoundSizes({0}MemRef); "
|
||||
"(void){0}Shape;\n",
|
||||
arg.name);
|
||||
for (auto en : llvm::enumerate(op.getOperands())) {
|
||||
os.indent(2) << formatv("auto &view_{0} = views[{1}]; "
|
||||
"(void)view_{0};\n",
|
||||
en.value().name, en.index());
|
||||
}
|
||||
|
||||
// Print the EDSC.
|
||||
os << ref->getAsUnquotedString() << "\n}";
|
||||
os << ref->getAsUnquotedString() << "\n";
|
||||
os.indent(2) << "block.print(llvm::outs());\n\n";
|
||||
os.indent(2) << "emitter.emitStmt(block);\n\n";
|
||||
os.indent(2) << "llvm::outs() << \"\\n\";\n\n";
|
||||
os << "}";
|
||||
}
|
||||
os << " else {"
|
||||
<< " f->emitError(\"no reference implementation for \" + opName);\n"
|
||||
<< " return;\n}\n";
|
||||
os << " block.print(llvm::outs());\n llvm::outs() << \"\\n\";\n"
|
||||
<< "}\n";
|
||||
os << " else {\n";
|
||||
os.indent(2) << "f->emitError(\"no reference impl. for \" + opName);\n";
|
||||
os.indent(2) << "return;\n";
|
||||
os << "}\n";
|
||||
os << "}\n";
|
||||
}
|
||||
|
||||
static mlir::GenRegistration
|
||||
|
|
Loading…
Reference in New Issue