forked from OSchip/llvm-project
EDSC: make Expr typed and extensible
Expose the result types of edsc::Expr, which are now stored for all types of Exprs and not only for the variadic ones. Require return types when an Expr is constructed, if it will ever have some. An empty return type list is interpreted as an Expr that does not create a value (e.g. `return` or `store`). Conceptually, all edss::Exprs are now typed, with the type being a (potentially empty) tuple of return types. Unbound expressions and Bindables must now be constructed with a specific type they will take. This makes EDSC less evidently type-polymorphic, but we can still write generic code such as Expr sumOfSquares(Expr lhs, Expr rhs) { return lhs * lhs + rhs * rhs; } and use it to construct different typed expressions as sumOfSquares(Bindable(IndexType::get(ctx)), Bindable(IndexType::get(ctx))); sumOfSquares(Bindable(FloatType::getF32(ctx)), Bindable(FloatType::getF32(ctx))); On the positive side, we get the following. 1. We can now perform type checking when constructing Exprs rather than during MLIR emission. Nevertheless, this is still duplicates the Op::verify() until we can factor out type checking from that. 2. MLIREmitter is significantly simplified. 3. ExprKind enum is only used for actual kinds of expressions. Data structures are converging with AbstractOperation, and the users can now create a VariadicExpr("canonical_op_name", {types}, {exprs}) for any operation, even an unregistered one without having to extend the enum and make pervasive changes to EDSCs. On the negative side, we get the following. 1. Typed bindables are more verbose, even in Python. 2. We lose the ability to do print debugging for higher-level EDSC abstractions that are implemented as multiple MLIR Ops, for example logical disjunction. This is the step 2/n towards making EDSC extensible. *** Move MLIR Op construction from MLIREmitter::emitExpr to Expr::build since Expr now has sufficient information to build itself. This is the step 3/n towards making EDSC extensible. Both of these strive to minimize the amount of irrelevant changes. In particular, this introduces more complex pretty-printing for affine and binary expression to make sure tests continue to pass. It also relies on string comparison to identify specific operations that an Expr produces. PiperOrigin-RevId: 234609882
This commit is contained in:
parent
e0fc503896
commit
b4dba895a6
|
@ -110,6 +110,9 @@ struct PythonMLIRModule {
|
|||
return ::makeMemRefType(mlir_context_t{&mlirContext}, elemType,
|
||||
int64_list_t{sizes.data(), sizes.size()});
|
||||
}
|
||||
PythonType makeIndexType() {
|
||||
return ::makeIndexType(mlir_context_t{&mlirContext});
|
||||
}
|
||||
PythonFunction makeFunction(const std::string &name,
|
||||
std::vector<PythonType> &inputTypes,
|
||||
std::vector<PythonType> &outputTypes) {
|
||||
|
@ -177,7 +180,8 @@ struct PythonExpr {
|
|||
};
|
||||
|
||||
struct PythonBindable : public PythonExpr {
|
||||
PythonBindable() : PythonExpr(edsc_expr_t{makeBindable()}) {}
|
||||
explicit PythonBindable(const PythonType &type)
|
||||
: PythonExpr(edsc_expr_t{makeBindable(type.type)}) {}
|
||||
PythonBindable(PythonExpr expr) : PythonExpr(expr) {
|
||||
assert(Expr(expr).isa<Bindable>() && "Expected Bindable");
|
||||
}
|
||||
|
@ -213,7 +217,6 @@ struct PythonBlock {
|
|||
};
|
||||
|
||||
struct PythonIndexed : public edsc_indexed_t {
|
||||
PythonIndexed() : edsc_indexed_t{makeIndexed(PythonBindable())} {}
|
||||
PythonIndexed(PythonExpr e) : edsc_indexed_t{makeIndexed(e)} {}
|
||||
PythonIndexed(PythonBindable b) : edsc_indexed_t{makeIndexed(b)} {}
|
||||
operator PythonExpr() { return PythonExpr(base); }
|
||||
|
@ -475,6 +478,8 @@ PYBIND11_MODULE(pybind, m) {
|
|||
.def("make_memref_type", &PythonMLIRModule::makeMemRefType,
|
||||
"Returns an mlir::MemRefType of an elemental scalar. -1 is used to "
|
||||
"denote symbolic dimensions in the resulting memref shape.")
|
||||
.def("make_index_type", &PythonMLIRModule::makeIndexType,
|
||||
"Returns an mlir::IndexType")
|
||||
.def("compile", &PythonMLIRModule::compile,
|
||||
"Compiles the mlir::Module to LLVMIR a creates new opaque "
|
||||
"ExecutionEngine backed by the ORC JIT.")
|
||||
|
@ -576,7 +581,7 @@ PYBIND11_MODULE(pybind, m) {
|
|||
m, "Bindable",
|
||||
"Wrapping class for mlir::edsc::Bindable.\nA Bindable is a special Expr "
|
||||
"that can be bound manually to specific MLIR SSA Values.")
|
||||
.def(py::init<>())
|
||||
.def(py::init<PythonType>())
|
||||
.def("__str__", &PythonBindable::str);
|
||||
|
||||
py::class_<PythonStmt>(m, "Stmt", "Wrapping class for mlir::edsc::Stmt.")
|
||||
|
@ -588,7 +593,6 @@ PYBIND11_MODULE(pybind, m) {
|
|||
m, "Indexed",
|
||||
"Wrapping class for mlir::edsc::Indexed.\nAn Indexed is a wrapper class "
|
||||
"that support load and store operations.")
|
||||
.def(py::init<>(), R"DOC(Build from fresh Bindable)DOC")
|
||||
.def(py::init<PythonExpr>(), R"DOC(Build from existing Expr)DOC")
|
||||
.def(py::init<PythonBindable>(), R"DOC(Build from existing Bindable)DOC")
|
||||
.def(
|
||||
|
|
|
@ -10,21 +10,30 @@ import google_mlir.bindings.python.pybind as E
|
|||
|
||||
class EdscTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.module = E.MLIRModule()
|
||||
self.boolType = self.module.make_scalar_type("i", 1)
|
||||
self.i32Type = self.module.make_scalar_type("i", 32)
|
||||
self.f32Type = self.module.make_scalar_type("f32")
|
||||
self.indexType = self.module.make_index_type()
|
||||
|
||||
def testBindables(self):
|
||||
with E.ContextManager():
|
||||
i = E.Expr(E.Bindable())
|
||||
i = E.Expr(E.Bindable(self.i32Type))
|
||||
self.assertIn("$1", i.__str__())
|
||||
|
||||
def testOneExpr(self):
|
||||
with E.ContextManager():
|
||||
i, lb, ub = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
|
||||
i, lb, ub = list(
|
||||
map(E.Expr, [E.Bindable(self.i32Type) for _ in range(3)]))
|
||||
expr = E.Mul(i, E.Add(lb, ub))
|
||||
str = expr.__str__()
|
||||
self.assertIn("($1 * ($2 + $3))", str)
|
||||
|
||||
def testOneLoop(self):
|
||||
with E.ContextManager():
|
||||
i, lb, ub, step = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
|
||||
i, lb, ub, step = list(
|
||||
map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
loop = E.For(i, lb, ub, step, [E.Stmt(E.Add(lb, ub))])
|
||||
str = loop.__str__()
|
||||
self.assertIn("for($1 = $2 to $3 step $4) {", str)
|
||||
|
@ -32,7 +41,8 @@ class EdscTest(unittest.TestCase):
|
|||
|
||||
def testTwoLoops(self):
|
||||
with E.ContextManager():
|
||||
i, lb, ub, step = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
|
||||
i, lb, ub, step = list(
|
||||
map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
loop = E.For(i, lb, ub, step, [E.For(i, lb, ub, step, [E.Stmt(i)])])
|
||||
str = loop.__str__()
|
||||
self.assertIn("for($1 = $2 to $3 step $4) {", str)
|
||||
|
@ -41,11 +51,12 @@ class EdscTest(unittest.TestCase):
|
|||
|
||||
def testNestedLoops(self):
|
||||
with E.ContextManager():
|
||||
i, lb, ub, step = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
|
||||
ivs = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
|
||||
lbs = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
|
||||
ubs = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
|
||||
steps = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
|
||||
i, lb, ub, step = list(
|
||||
map(E.Expr, [E.Bindable(self.i32Type) for _ in range(4)]))
|
||||
ivs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
steps = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
loop = E.For(ivs, lbs, ubs, steps, [
|
||||
E.For(i, lb, ub, step, [E.Stmt(ub * step - lb)]),
|
||||
])
|
||||
|
@ -59,20 +70,23 @@ class EdscTest(unittest.TestCase):
|
|||
|
||||
def testIndexed(self):
|
||||
with E.ContextManager():
|
||||
i, j, k = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
|
||||
A, B, C = list(map(E.Indexed, [E.Bindable() for _ in range(3)]))
|
||||
i, j, k = list(
|
||||
map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
memrefType = self.module.make_memref_type(self.f32Type, [42, 42])
|
||||
A, B, C = list(map(E.Indexed, [E.Bindable(memrefType) for _ in range(3)]))
|
||||
stmt = C.store([i, j], A.load([i, k]) * B.load([k, j]))
|
||||
str = stmt.__str__()
|
||||
self.assertIn(" = store(", str)
|
||||
|
||||
def testMatmul(self):
|
||||
with E.ContextManager():
|
||||
ivs = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
|
||||
lbs = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
|
||||
ubs = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
|
||||
steps = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
|
||||
ivs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
steps = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
i, j, k = ivs[0], ivs[1], ivs[2]
|
||||
A, B, C = list(map(E.Indexed, [E.Bindable() for _ in range(3)]))
|
||||
memrefType = self.module.make_memref_type(self.f32Type, [42, 42])
|
||||
A, B, C = list(map(E.Indexed, [E.Bindable(memrefType) for _ in range(3)]))
|
||||
loop = E.For(
|
||||
ivs, lbs, ubs, steps,
|
||||
[C.store([i, j],
|
||||
|
@ -85,29 +99,36 @@ class EdscTest(unittest.TestCase):
|
|||
|
||||
def testArithmetic(self):
|
||||
with E.ContextManager():
|
||||
i, j, k, l = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
|
||||
i, j, k, l = list(
|
||||
map(E.Expr, [E.Bindable(self.f32Type) for _ in range(4)]))
|
||||
stmt = i + j * k - l
|
||||
str = stmt.__str__()
|
||||
self.assertIn("(($1 + ($2 * $3)) - $4)", str)
|
||||
|
||||
def testBoolean(self):
|
||||
with E.ContextManager():
|
||||
i, j, k, l = list(map(E.Expr, [E.Bindable() for _ in range(4)]))
|
||||
i, j, k, l = list(
|
||||
map(E.Expr, [E.Bindable(self.i32Type) for _ in range(4)]))
|
||||
stmt1 = (i < j) & (j >= k)
|
||||
stmt2 = ~(stmt1 | (k == l))
|
||||
str = stmt2.__str__()
|
||||
self.assertIn("~((($1 < $2) && ($2 >= $3)) || ($3 == $4))", str)
|
||||
# Note that "a | b" is currently implemented as ~(~a && ~b) and "~a" is
|
||||
# currently implemented as "constant 1 - a", which leads to this
|
||||
# expression.
|
||||
self.assertIn(
|
||||
"(constant({value: 1}) - (constant({value: 1}) - ((constant({value: 1}) - (($1 < $2) && ($2 >= $3))) && (constant({value: 1}) - ($3 == $4)))))",
|
||||
str)
|
||||
|
||||
def testSelect(self):
|
||||
with E.ContextManager():
|
||||
i, j, k = list(map(E.Expr, [E.Bindable() for _ in range(3)]))
|
||||
i, j, k = list(map(E.Expr, [E.Bindable(self.i32Type) for _ in range(3)]))
|
||||
stmt = E.Select(i > j, i, j)
|
||||
str = stmt.__str__()
|
||||
self.assertIn("select(($1 > $2), $1, $2)", str)
|
||||
|
||||
def testBlock(self):
|
||||
with E.ContextManager():
|
||||
i, j = list(map(E.Expr, [E.Bindable() for _ in range(2)]))
|
||||
i, j = list(map(E.Expr, [E.Bindable(self.f32Type) for _ in range(2)]))
|
||||
stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)])
|
||||
str = stmt.__str__()
|
||||
self.assertIn("^bb:", str)
|
||||
|
@ -175,16 +196,14 @@ class EdscTest(unittest.TestCase):
|
|||
self.assertIn("constant 123 : index", str)
|
||||
|
||||
def testMLIRBooleanEmission(self):
|
||||
module = E.MLIRModule()
|
||||
t = module.make_scalar_type("i", 1)
|
||||
m = module.make_memref_type(t, [10]) # i1 tensor
|
||||
f = module.make_function("mkbooltensor", [m, m], [])
|
||||
m = self.module.make_memref_type(self.boolType, [10]) # i1 tensor
|
||||
f = self.module.make_function("mkbooltensor", [m, m], [])
|
||||
with E.ContextManager():
|
||||
emitter = E.MLIRFunctionEmitter(f)
|
||||
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
|
||||
i = E.Expr(E.Bindable())
|
||||
j = E.Expr(E.Bindable())
|
||||
k = E.Expr(E.Bindable())
|
||||
i = E.Expr(E.Bindable(self.indexType))
|
||||
j = E.Expr(E.Bindable(self.indexType))
|
||||
k = E.Expr(E.Bindable(self.indexType))
|
||||
idxs = [i, j, k]
|
||||
zero = emitter.bind_constant_index(0)
|
||||
one = emitter.bind_constant_index(1)
|
||||
|
@ -201,17 +220,13 @@ class EdscTest(unittest.TestCase):
|
|||
emitter.emit_inplace(loop)
|
||||
# str = f.__str__()
|
||||
# print(str)
|
||||
module.compile()
|
||||
self.assertNotEqual(module.get_engine_address(), 0)
|
||||
self.module.compile()
|
||||
self.assertNotEqual(self.module.get_engine_address(), 0)
|
||||
|
||||
# TODO(ntv): support symbolic For bounds with EDSCs
|
||||
def testMLIREmission(self):
|
||||
shape = [3, 4, 5]
|
||||
module = E.MLIRModule()
|
||||
index = module.make_scalar_type("index")
|
||||
t = module.make_scalar_type("f32")
|
||||
m = module.make_memref_type(t, shape)
|
||||
f = module.make_function("copy", [m, m], [])
|
||||
m = self.module.make_memref_type(self.f32Type, shape)
|
||||
f = self.module.make_function("copy", [m, m], [])
|
||||
|
||||
with E.ContextManager():
|
||||
emitter = E.MLIRFunctionEmitter(f)
|
||||
|
@ -220,7 +235,8 @@ class EdscTest(unittest.TestCase):
|
|||
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
|
||||
M, N, O = emitter.bind_indexed_shape(input)
|
||||
|
||||
ivs = list(map(E.Expr, [E.Bindable() for _ in range(len(shape))]))
|
||||
ivs = list(
|
||||
map(E.Expr, [E.Bindable(self.indexType) for _ in range(len(shape))]))
|
||||
lbs = [zero, zero, zero]
|
||||
ubs = [M, N, O]
|
||||
steps = [one, one, one]
|
||||
|
@ -237,8 +253,9 @@ class EdscTest(unittest.TestCase):
|
|||
self.assertIn("""store %0, %arg1[%i0, %i1, %i2] : memref<3x4x5xf32>""",
|
||||
str)
|
||||
|
||||
module.compile()
|
||||
self.assertNotEqual(module.get_engine_address(), 0)
|
||||
self.module.compile()
|
||||
self.assertNotEqual(self.module.get_engine_address(), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -24,7 +24,10 @@ class EdscTest(unittest.TestCase):
|
|||
emitter = E.MLIRFunctionEmitter(f)
|
||||
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
|
||||
lbs, ubs, steps = emitter.bind_indexed_view(input)
|
||||
i, *ivs, j = list(map(E.Expr, [E.Bindable() for _ in range(len(shape))]))
|
||||
i, *ivs, j = list(
|
||||
map(E.Expr,
|
||||
[E.Bindable(module.make_index_type()) for _ in range(len(shape))
|
||||
]))
|
||||
|
||||
# n-D type and rank agnostic copy-transpose-first-last (where n >= 2).
|
||||
loop = E.Block([
|
||||
|
|
|
@ -106,6 +106,9 @@ mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType,
|
|||
mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
|
||||
mlir_type_list_t outputs);
|
||||
|
||||
/// Returns an `mlir::IndexType`.
|
||||
mlir_type_t makeIndexType(mlir_context_t context);
|
||||
|
||||
/// Returns the arity of `function`.
|
||||
unsigned getFunctionArity(mlir_func_t function);
|
||||
|
||||
|
@ -189,7 +192,7 @@ void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
|
|||
edsc_expr_list_t *resultSteps);
|
||||
|
||||
/// Returns an opaque expression for an mlir::edsc::Expr.
|
||||
edsc_expr_t makeBindable();
|
||||
edsc_expr_t makeBindable(mlir_type_t type);
|
||||
|
||||
/// Returns an opaque expression for an mlir::edsc::Stmt.
|
||||
edsc_stmt_t makeStmt(edsc_expr_t e);
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#define MLIR_LIB_EDSC_TYPES_H_
|
||||
|
||||
#include "mlir-c/Core.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
|
@ -36,6 +37,7 @@
|
|||
namespace mlir {
|
||||
|
||||
class MLIRContext;
|
||||
class FuncBuilder;
|
||||
|
||||
namespace edsc {
|
||||
namespace detail {
|
||||
|
@ -72,40 +74,14 @@ enum class ExprKind {
|
|||
LAST_BINDABLE_EXPR = Unbound,
|
||||
FIRST_NON_BINDABLE_EXPR = 200,
|
||||
FIRST_UNARY_EXPR = FIRST_NON_BINDABLE_EXPR,
|
||||
Dealloc = FIRST_UNARY_EXPR,
|
||||
Negate,
|
||||
LAST_UNARY_EXPR = Negate,
|
||||
LAST_UNARY_EXPR = FIRST_UNARY_EXPR,
|
||||
FIRST_BINARY_EXPR = 300,
|
||||
Add = FIRST_BINARY_EXPR,
|
||||
Sub,
|
||||
Mul,
|
||||
Div,
|
||||
AddEQ,
|
||||
SubEQ,
|
||||
MulEQ,
|
||||
DivEQ,
|
||||
GE,
|
||||
GT,
|
||||
LE,
|
||||
LT,
|
||||
EQ,
|
||||
NE,
|
||||
And,
|
||||
Or,
|
||||
LAST_BINARY_EXPR = Or,
|
||||
LAST_BINARY_EXPR = FIRST_BINARY_EXPR,
|
||||
FIRST_TERNARY_EXPR = 400,
|
||||
Select = FIRST_TERNARY_EXPR,
|
||||
IfThenElse,
|
||||
LAST_TERNARY_EXPR = IfThenElse,
|
||||
FIRST_VARIADIC_EXPR = 500,
|
||||
Alloc = FIRST_VARIADIC_EXPR, // Variadic because takes multiple dynamic shape
|
||||
// values.
|
||||
Load,
|
||||
Store,
|
||||
VectorTypeCast, // Variadic because takes a type and anything taking a type
|
||||
// is variadic for now.
|
||||
Return,
|
||||
LAST_VARIADIC_EXPR = Return,
|
||||
LAST_VARIADIC_EXPR = FIRST_VARIADIC_EXPR,
|
||||
FIRST_STMT_BLOCK_LIKE_EXPR = 600,
|
||||
For = FIRST_STMT_BLOCK_LIKE_EXPR,
|
||||
LAST_STMT_BLOCK_LIKE_EXPR = For,
|
||||
|
@ -164,7 +140,7 @@ public:
|
|||
return allocator;
|
||||
}
|
||||
|
||||
Expr();
|
||||
explicit Expr(Type type);
|
||||
/* implicit */ Expr(ImplType *storage) : storage(storage) {}
|
||||
explicit Expr(edsc_expr_t expr)
|
||||
: storage(reinterpret_cast<ImplType *>(expr)) {}
|
||||
|
@ -180,6 +156,20 @@ public:
|
|||
/// Returns the classification for this type.
|
||||
ExprKind getKind() const;
|
||||
unsigned getId() const;
|
||||
StringRef getName() const;
|
||||
|
||||
/// Returns the types of the values this expression produces.
|
||||
ArrayRef<Type> getResultTypes() const;
|
||||
|
||||
/// Returns the list of expressions used as arguments of this expression.
|
||||
ArrayRef<Expr> getChildExpressions() const;
|
||||
|
||||
/// Returns the list of attributes of this expression.
|
||||
ArrayRef<NamedAttribute> getAttributes() const;
|
||||
|
||||
/// Build the IR corresponding to this expression.
|
||||
SmallVector<Value *, 4>
|
||||
build(FuncBuilder &b, const llvm::DenseMap<Expr, Value *> &ssaBindings) const;
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
@ -216,7 +206,7 @@ private:
|
|||
struct UnaryExpr : public Expr {
|
||||
friend class Expr;
|
||||
|
||||
UnaryExpr(ExprKind kind, Expr expr);
|
||||
UnaryExpr(StringRef name, Expr expr);
|
||||
Expr getExpr() const;
|
||||
|
||||
protected:
|
||||
|
@ -227,7 +217,8 @@ protected:
|
|||
|
||||
struct BinaryExpr : public Expr {
|
||||
friend class Expr;
|
||||
BinaryExpr(ExprKind kind, Expr lhs, Expr rhs);
|
||||
BinaryExpr(StringRef name, Type result, Expr lhs, Expr rhs,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
Expr getLHS() const;
|
||||
Expr getRHS() const;
|
||||
|
||||
|
@ -239,7 +230,7 @@ protected:
|
|||
|
||||
struct TernaryExpr : public Expr {
|
||||
friend class Expr;
|
||||
TernaryExpr(ExprKind kind, Expr cond, Expr lhs, Expr rhs);
|
||||
TernaryExpr(StringRef name, Expr cond, Expr lhs, Expr rhs);
|
||||
Expr getCond() const;
|
||||
Expr getLHS() const;
|
||||
Expr getRHS() const;
|
||||
|
@ -252,8 +243,9 @@ protected:
|
|||
|
||||
struct VariadicExpr : public Expr {
|
||||
friend class Expr;
|
||||
VariadicExpr(ExprKind kind, llvm::ArrayRef<Expr> exprs,
|
||||
llvm::ArrayRef<Type> types = {});
|
||||
VariadicExpr(StringRef name, llvm::ArrayRef<Expr> exprs,
|
||||
llvm::ArrayRef<Type> types = {},
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
llvm::ArrayRef<Expr> getExprs() const;
|
||||
llvm::ArrayRef<Type> getTypes() const;
|
||||
|
||||
|
@ -554,7 +546,7 @@ namespace edsc {
|
|||
/// `llvm::SmallVector<Expr, 8> dims(n);` directly because a single
|
||||
/// `Expr` will be default constructed and copied everywhere in the vector.
|
||||
/// Hilarity ensues when trying to bind `Expr` multiple times.
|
||||
llvm::SmallVector<Expr, 8> makeNewExprs(unsigned n);
|
||||
llvm::SmallVector<Expr, 8> makeNewExprs(unsigned n, Type type);
|
||||
template <typename IterTy>
|
||||
llvm::SmallVector<Expr, 8> copyExprs(IterTy begin, IterTy end) {
|
||||
return llvm::SmallVector<Expr, 8>(begin, end);
|
||||
|
|
|
@ -44,13 +44,14 @@ char LowerEDSCTestPass::passID = 0;
|
|||
#include "mlir/EDSC/reference-impl.inc"
|
||||
|
||||
PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
|
||||
// Inject a EDSC-constructed list of blocks.
|
||||
if (f->getName().strref() == "blocks") {
|
||||
using namespace edsc::op;
|
||||
|
||||
FuncBuilder builder(f);
|
||||
edsc::ScopedEDSCContext context;
|
||||
edsc::Expr arg1, arg2, arg3, arg4;
|
||||
auto type = builder.getIntegerType(32);
|
||||
edsc::Expr arg1(type), arg2(type), arg3(type), arg4(type);
|
||||
|
||||
auto b1 =
|
||||
edsc::block({arg1, arg2}, {type, type}, {arg1 + arg2, edsc::Return()});
|
||||
|
@ -73,8 +74,9 @@ PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
|
|||
"dynamic_for expected index arguments");
|
||||
}
|
||||
|
||||
Type index = IndexType::get(f->getContext());
|
||||
edsc::ScopedEDSCContext context;
|
||||
edsc::Expr lb, ub, step;
|
||||
edsc::Expr lb(index), ub(index), step(index);
|
||||
auto loop = edsc::For(lb, ub, step, {});
|
||||
edsc::MLIREmitter(&builder, f->getLoc())
|
||||
.bind(edsc::Bindable(lb), f->getArgument(0))
|
||||
|
@ -83,6 +85,7 @@ PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
|
|||
.emitStmt(loop);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Inject a EDSC-constructed `for` loop with non-constant bounds that are
|
||||
// obtained from AffineApplyOp (also constructed using EDSC operator
|
||||
// overloads).
|
||||
|
@ -97,8 +100,9 @@ PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
|
|||
"dynamic_for expected index arguments");
|
||||
}
|
||||
|
||||
Type index = IndexType::get(f->getContext());
|
||||
edsc::ScopedEDSCContext context;
|
||||
edsc::Expr lb1, lb2, ub1, ub2, step;
|
||||
edsc::Expr lb1(index), lb2(index), ub1(index), ub2(index), step(index);
|
||||
using namespace edsc::op;
|
||||
auto lb = lb1 - lb2;
|
||||
auto ub = ub1 + ub2;
|
||||
|
|
|
@ -145,7 +145,8 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) {
|
|||
}
|
||||
|
||||
mlir::edsc::MLIREmitter::MLIREmitter(FuncBuilder *builder, Location location)
|
||||
: builder(builder), location(location) {
|
||||
: builder(builder), location(location), zeroIndex(builder->getIndexType()),
|
||||
oneIndex(builder->getIndexType()) {
|
||||
// Build the ubiquitous zero and one at the top of the function.
|
||||
bindConstant<ConstantIndexOp>(Bindable(zeroIndex), 0);
|
||||
bindConstant<ConstantIndexOp>(Bindable(oneIndex), 1);
|
||||
|
@ -166,140 +167,23 @@ MLIREmitter &mlir::edsc::MLIREmitter::bind(Bindable e, Value *v) {
|
|||
}
|
||||
|
||||
Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
|
||||
// It is still necessary in case we try to emit a bindable directly
|
||||
// FIXME: make sure isa<Bindable> works and use it below to delegate emission
|
||||
// to Expr::build and remove this, now duplicate, check.
|
||||
auto it = ssaBindings.find(e);
|
||||
if (it != ssaBindings.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Skip bindables, they must have been found already.
|
||||
Value *res = nullptr;
|
||||
if (auto un = e.dyn_cast<UnaryExpr>()) {
|
||||
if (un.getKind() == ExprKind::Dealloc) {
|
||||
builder->create<DeallocOp>(location, emitExpr(un.getExpr()));
|
||||
return nullptr;
|
||||
} else if (un.getKind() == ExprKind::Negate) {
|
||||
auto ctrue = builder->create<mlir::ConstantIntOp>(location, 1,
|
||||
builder->getI1Type());
|
||||
// TODO(dvytin): worth binding constant in ssaBindings in the future?
|
||||
// TODO(dvytin): no need to cast getExpr() to I1?
|
||||
auto val = emitExpr(un.getExpr());
|
||||
assert(val->getType().isInteger(1) &&
|
||||
"Logical Negate expects i1 operand");
|
||||
return sub(builder, location, ctrue, val);
|
||||
}
|
||||
} else if (auto bin = e.dyn_cast<BinaryExpr>()) {
|
||||
auto lhs = bin.getLHS();
|
||||
auto rhs = bin.getRHS();
|
||||
auto *a = emitExpr(lhs);
|
||||
auto *b = emitExpr(rhs);
|
||||
if (!a || !b) {
|
||||
return nullptr;
|
||||
}
|
||||
if (bin.getKind() == ExprKind::Add) {
|
||||
res = add(builder, location, a, b);
|
||||
} else if (bin.getKind() == ExprKind::Sub) {
|
||||
res = sub(builder, location, a, b);
|
||||
} else if (bin.getKind() == ExprKind::Mul) {
|
||||
res = mul(builder, location, a, b);
|
||||
} else if (bin.getKind() == ExprKind::And) {
|
||||
// Operands should both be i1
|
||||
assert(a->getType().isInteger(1) && "Logical And expects i1 LHS");
|
||||
assert(b->getType().isInteger(1) && "Logical And expects i1 RHS");
|
||||
res = mul(builder, location, a, b);
|
||||
} else if (bin.getKind() == ExprKind::Or) {
|
||||
assert(a->getType().isInteger(1) && "Logical Or expects i1 LHS");
|
||||
assert(b->getType().isInteger(1) && "Logical Or expects i1 RHS");
|
||||
// a || b = not (not a && not b)
|
||||
using edsc::op::operator!;
|
||||
using edsc::op::operator&&;
|
||||
res = emitExpr(!(!lhs && !rhs));
|
||||
} // TODO(ntv): signed vs unsiged ??
|
||||
// TODO(ntv): integer vs not ??
|
||||
// TODO(ntv): float cmp
|
||||
else if (bin.getKind() == ExprKind::EQ) {
|
||||
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::EQ, a, b);
|
||||
} else if (bin.getKind() == ExprKind::NE) {
|
||||
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::NE, a, b);
|
||||
} else if (bin.getKind() == ExprKind::LT) {
|
||||
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::SLT, a, b);
|
||||
} else if (bin.getKind() == ExprKind::LE) {
|
||||
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::SLE, a, b);
|
||||
} else if (bin.getKind() == ExprKind::GT) {
|
||||
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::SGT, a, b);
|
||||
} else if (bin.getKind() == ExprKind::GE) {
|
||||
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::SGE, a, b);
|
||||
}
|
||||
|
||||
// TODO(ntv): do we want this?
|
||||
// if (res && ((a->type().is_uint() && !b->type().is_uint()) ||
|
||||
// (!a->type().is_uint() && b->type().is_uint()))) {
|
||||
// std::stringstream ss;
|
||||
// ss << "a: " << *a << "\t b: " << *b;
|
||||
// res->getDefiningOperation()->emitWarning(
|
||||
// "Mixing signed and unsigned integers: " + ss.str());
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
if (auto ter = e.dyn_cast<TernaryExpr>()) {
|
||||
if (ter.getKind() == ExprKind::Select) {
|
||||
auto *cond = emitExpr(ter.getCond());
|
||||
auto *lhs = emitExpr(ter.getLHS());
|
||||
auto *rhs = emitExpr(ter.getRHS());
|
||||
if (!cond || !rhs || !lhs) {
|
||||
return nullptr;
|
||||
}
|
||||
res = builder->create<SelectOp>(location, cond, lhs, rhs)->getResult();
|
||||
}
|
||||
}
|
||||
|
||||
if (auto nar = e.dyn_cast<VariadicExpr>()) {
|
||||
if (nar.getKind() == ExprKind::Alloc) {
|
||||
auto exprs = emitExprs(nar.getExprs());
|
||||
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
|
||||
return nullptr;
|
||||
}
|
||||
auto types = nar.getTypes();
|
||||
assert(types.size() == 1 && "Expected 1 type");
|
||||
res =
|
||||
builder->create<AllocOp>(location, types[0].cast<MemRefType>(), exprs)
|
||||
->getResult();
|
||||
} else if (nar.getKind() == ExprKind::Load) {
|
||||
auto exprs = emitExprs(nar.getExprs());
|
||||
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
|
||||
return nullptr;
|
||||
}
|
||||
assert(!exprs.empty() && "Load requires >= 1 exprs");
|
||||
assert(nar.getTypes().empty() && "Load expects no type");
|
||||
SmallVector<Value *, 8> vals(exprs.begin() + 1, exprs.end());
|
||||
res = builder->create<LoadOp>(location, exprs[0], vals)->getResult();
|
||||
} else if (nar.getKind() == ExprKind::Store) {
|
||||
auto exprs = emitExprs(nar.getExprs());
|
||||
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
|
||||
return nullptr;
|
||||
}
|
||||
assert(exprs.size() >= 2 && "Store requires >= 2 exprs");
|
||||
assert(nar.getTypes().empty() && "Store expects no type");
|
||||
SmallVector<Value *, 8> vals(exprs.begin() + 2, exprs.end());
|
||||
builder->create<StoreOp>(location, exprs[0], exprs[1], vals);
|
||||
return nullptr;
|
||||
} else if (nar.getKind() == ExprKind::VectorTypeCast) {
|
||||
auto exprs = emitExprs(nar.getExprs());
|
||||
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
|
||||
return nullptr;
|
||||
}
|
||||
assert(exprs.size() == 1 && "Expected 1 expr");
|
||||
auto types = nar.getTypes();
|
||||
assert(types.size() == 1 && "Expected 1 type");
|
||||
res = builder
|
||||
->create<VectorTypeCastOp>(location, exprs[0],
|
||||
types[0].cast<MemRefType>())
|
||||
->getResult();
|
||||
} else if (nar.getKind() == ExprKind::Return) {
|
||||
auto exprs = emitExprs(nar.getExprs());
|
||||
builder->create<ReturnOp>(location, exprs);
|
||||
return nullptr; // no Value* produced and this is fine.
|
||||
}
|
||||
bool expectedEmpty = false;
|
||||
if (e.isa<UnaryExpr>() || e.isa<BinaryExpr>() || e.isa<TernaryExpr>() ||
|
||||
e.isa<VariadicExpr>()) {
|
||||
auto results = e.build(*builder, ssaBindings);
|
||||
assert(results.size() <= 1 && "2+-result exprs are not supported");
|
||||
expectedEmpty = results.empty();
|
||||
if (!results.empty())
|
||||
res = results.front();
|
||||
}
|
||||
|
||||
if (auto expr = e.dyn_cast<StmtBlockLikeExpr>()) {
|
||||
|
@ -349,7 +233,7 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
|
|||
}
|
||||
}
|
||||
|
||||
if (!res) {
|
||||
if (!res && !expectedEmpty) {
|
||||
// If we hit here it must mean that the Bindables have not all been bound
|
||||
// properly. Because EDSCs are currently dynamically typed, it becomes a
|
||||
// runtime error.
|
||||
|
@ -386,9 +270,9 @@ void mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) {
|
|||
auto ip = builder->getInsertionPoint();
|
||||
auto *val = emitExpr(stmt.getRHS());
|
||||
if (!val) {
|
||||
assert((stmt.getRHS().getKind() == ExprKind::Dealloc ||
|
||||
stmt.getRHS().getKind() == ExprKind::Store ||
|
||||
stmt.getRHS().getKind() == ExprKind::Return) &&
|
||||
assert((stmt.getRHS().getName() == DeallocOp::getOperationName() ||
|
||||
stmt.getRHS().getName() == StoreOp::getOperationName() ||
|
||||
stmt.getRHS().getName() == ReturnOp::getOperationName()) &&
|
||||
"dealloc, store or return expected as the only 0-result ops");
|
||||
return;
|
||||
}
|
||||
|
@ -491,7 +375,7 @@ mlir::edsc::MLIREmitter::makeBoundFunctionArguments(mlir::Function *function) {
|
|||
for (unsigned pos = 0, npos = function->getNumArguments(); pos < npos;
|
||||
++pos) {
|
||||
auto *arg = function->getArgument(pos);
|
||||
Expr b;
|
||||
Expr b(arg->getType());
|
||||
bind(Bindable(b), arg);
|
||||
res.push_back(Expr(b));
|
||||
}
|
||||
|
@ -502,7 +386,8 @@ SmallVector<edsc::Expr, 8>
|
|||
mlir::edsc::MLIREmitter::makeBoundMemRefShape(Value *memRef) {
|
||||
assert(memRef->getType().isa<MemRefType>() && "Expected a MemRef value");
|
||||
MemRefType memRefType = memRef->getType().cast<MemRefType>();
|
||||
auto memRefSizes = edsc::makeNewExprs(memRefType.getShape().size());
|
||||
auto memRefSizes =
|
||||
edsc::makeNewExprs(memRefType.getShape().size(), builder->getIndexType());
|
||||
auto memrefSizeValues = getMemRefSizes(getBuilder(), getLocation(), memRef);
|
||||
assert(memrefSizeValues.size() == memRefSizes.size());
|
||||
bindZipRange(llvm::zip(memRefSizes, memrefSizeValues));
|
||||
|
@ -517,7 +402,7 @@ mlir::edsc::MLIREmitter::makeBoundMemRefView(Value *memRef) {
|
|||
|
||||
SmallVector<edsc::Expr, 8> lbs;
|
||||
lbs.reserve(rank);
|
||||
Expr zero;
|
||||
Expr zero(builder->getIndexType());
|
||||
bindConstant<mlir::ConstantIndexOp>(Bindable(zero), 0);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
lbs.push_back(zero);
|
||||
|
@ -527,7 +412,7 @@ mlir::edsc::MLIREmitter::makeBoundMemRefView(Value *memRef) {
|
|||
|
||||
SmallVector<edsc::Expr, 8> steps;
|
||||
lbs.reserve(rank);
|
||||
Expr one;
|
||||
Expr one(builder->getIndexType());
|
||||
bindConstant<mlir::ConstantIndexOp>(Bindable(one), 1);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
steps.push_back(one);
|
||||
|
@ -545,7 +430,7 @@ mlir::edsc::MLIREmitter::makeBoundMemRefView(Expr boundMemRef) {
|
|||
|
||||
edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Expr b;
|
||||
Expr b(e->getBuilder()->getBF16Type());
|
||||
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value),
|
||||
e->getBuilder()->getBF16Type());
|
||||
return b;
|
||||
|
@ -553,7 +438,7 @@ edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value) {
|
|||
|
||||
edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Expr b;
|
||||
Expr b(e->getBuilder()->getBF16Type());
|
||||
bool unused;
|
||||
mlir::APFloat val(value);
|
||||
val.convert(e->getBuilder()->getF16Type().getFloatSemantics(),
|
||||
|
@ -565,7 +450,7 @@ edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value) {
|
|||
|
||||
edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Expr b;
|
||||
Expr b(e->getBuilder()->getF32Type());
|
||||
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value),
|
||||
e->getBuilder()->getF32Type());
|
||||
return b;
|
||||
|
@ -573,7 +458,7 @@ edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value) {
|
|||
|
||||
edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Expr b;
|
||||
Expr b(e->getBuilder()->getF64Type());
|
||||
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value),
|
||||
e->getBuilder()->getF64Type());
|
||||
return b;
|
||||
|
@ -582,7 +467,7 @@ edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value) {
|
|||
edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value,
|
||||
unsigned bitwidth) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Expr b;
|
||||
Expr b(e->getBuilder()->getIntegerType(bitwidth));
|
||||
e->bindConstant<mlir::ConstantIntOp>(
|
||||
b, value, e->getBuilder()->getIntegerType(bitwidth));
|
||||
return b;
|
||||
|
@ -590,7 +475,7 @@ edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value,
|
|||
|
||||
edsc_expr_t bindConstantIndex(edsc_mlir_emitter_t emitter, int64_t value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Expr b;
|
||||
Expr b(e->getBuilder()->getIndexType());
|
||||
e->bindConstant<mlir::ConstantIndexOp>(Bindable(b), value);
|
||||
return b;
|
||||
}
|
||||
|
@ -618,7 +503,7 @@ edsc_expr_t bindFunctionArgument(edsc_mlir_emitter_t emitter,
|
|||
auto *f = reinterpret_cast<mlir::Function *>(function);
|
||||
assert(pos < f->getNumArguments());
|
||||
auto *arg = *(f->getArguments().begin() + pos);
|
||||
Expr b;
|
||||
Expr b(arg->getType());
|
||||
e->bind(Bindable(b), arg);
|
||||
return Expr(b);
|
||||
}
|
||||
|
@ -630,7 +515,7 @@ void bindFunctionArguments(edsc_mlir_emitter_t emitter, mlir_func_t function,
|
|||
assert(result->n == f->getNumArguments());
|
||||
for (unsigned pos = 0; pos < result->n; ++pos) {
|
||||
auto *arg = *(f->getArguments().begin() + pos);
|
||||
Expr b;
|
||||
Expr b(arg->getType());
|
||||
e->bind(Bindable(b), arg);
|
||||
result->exprs[pos] = Expr(b);
|
||||
}
|
||||
|
@ -670,9 +555,9 @@ void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
|
|||
assert(resultUbs->n == rank && "Unexpected memref binding results count");
|
||||
assert(resultSteps->n == rank && "Unexpected memref binding results count");
|
||||
auto bindables = e->makeBoundMemRefShape(v);
|
||||
Expr zero;
|
||||
Expr zero(e->getBuilder()->getIndexType());
|
||||
e->bindConstant<mlir::ConstantIndexOp>(zero, 0);
|
||||
Expr one;
|
||||
Expr one(e->getBuilder()->getIndexType());
|
||||
e->bindConstant<mlir::ConstantIndexOp>(one, 1);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
resultLbs->exprs[i] = zero;
|
||||
|
|
|
@ -17,9 +17,16 @@
|
|||
|
||||
#include "mlir/EDSC/Types.h"
|
||||
#include "mlir-c/Core.h"
|
||||
#include "mlir/AffineOps/AffineOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineExprVisitor.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/StandardOps/StandardOps.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
@ -54,16 +61,23 @@ struct ExprStorage {
|
|||
ExprKind kind;
|
||||
unsigned id;
|
||||
|
||||
StringRef opName;
|
||||
ArrayRef<Expr> operands;
|
||||
ArrayRef<Type> resultTypes;
|
||||
ArrayRef<NamedAttribute> attributes;
|
||||
|
||||
ExprStorage(ExprKind kind, ArrayRef<Type> results, ArrayRef<Expr> children,
|
||||
ArrayRef<NamedAttribute> attrs, unsigned exprId = Expr::newId())
|
||||
ExprStorage(ExprKind kind, StringRef name, ArrayRef<Type> results,
|
||||
ArrayRef<Expr> children, ArrayRef<NamedAttribute> attrs,
|
||||
StringRef descr = "", unsigned exprId = Expr::newId())
|
||||
: kind(kind), id(exprId) {
|
||||
operands = copyIntoExprAllocator(children);
|
||||
resultTypes = copyIntoExprAllocator(results);
|
||||
attributes = copyIntoExprAllocator(attrs);
|
||||
if (!name.empty()) {
|
||||
auto nameStorage = Expr::globalAllocator()->Allocate<char>(name.size());
|
||||
std::uninitialized_copy(name.begin(), name.end(), nameStorage);
|
||||
opName = StringRef(nameStorage, name.size());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -101,10 +115,10 @@ mlir::edsc::ScopedEDSCContext::~ScopedEDSCContext() {
|
|||
Expr::globalAllocator() = nullptr;
|
||||
}
|
||||
|
||||
mlir::edsc::Expr::Expr() {
|
||||
mlir::edsc::Expr::Expr(Type type) {
|
||||
// Initialize with placement new.
|
||||
storage = Expr::globalAllocator()->Allocate<detail::ExprStorage>();
|
||||
new (storage) detail::ExprStorage(ExprKind::Unbound, {}, {}, {});
|
||||
new (storage) detail::ExprStorage(ExprKind::Unbound, "", {type}, {}, {});
|
||||
}
|
||||
|
||||
ExprKind mlir::edsc::Expr::getKind() const { return storage->kind; }
|
||||
|
@ -118,49 +132,173 @@ unsigned &mlir::edsc::Expr::newId() {
|
|||
return ++id;
|
||||
}
|
||||
|
||||
ArrayRef<Type> mlir::edsc::Expr::getResultTypes() const {
|
||||
return storage->resultTypes;
|
||||
}
|
||||
|
||||
ArrayRef<Expr> mlir::edsc::Expr::getChildExpressions() const {
|
||||
return storage->operands;
|
||||
}
|
||||
|
||||
ArrayRef<NamedAttribute> mlir::edsc::Expr::getAttributes() const {
|
||||
return storage->attributes;
|
||||
}
|
||||
|
||||
StringRef mlir::edsc::Expr::getName() const {
|
||||
return static_cast<ImplType *>(storage)->opName;
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4>
|
||||
Expr::build(FuncBuilder &b,
|
||||
const llvm::DenseMap<Expr, Value *> &ssaBindings) const {
|
||||
auto it = ssaBindings.find(*this);
|
||||
if (it != ssaBindings.end())
|
||||
return {it->second};
|
||||
|
||||
auto *impl = static_cast<ImplType *>(storage);
|
||||
auto state = OperationState(b.getContext(), b.getUnknownLoc(), impl->opName);
|
||||
SmallVector<Value *, 4> operandValues;
|
||||
operandValues.reserve(impl->operands.size());
|
||||
for (auto child : impl->operands) {
|
||||
auto subResults = child.build(b, ssaBindings);
|
||||
assert(subResults.size() == 1 &&
|
||||
"expected single-result expression as operand");
|
||||
operandValues.push_back(subResults.front());
|
||||
}
|
||||
state.addOperands(operandValues);
|
||||
state.addTypes(impl->resultTypes);
|
||||
for (const auto &attr : impl->attributes)
|
||||
state.addAttribute(attr.first, attr.second);
|
||||
|
||||
Instruction *inst = b.createOperation(state);
|
||||
return llvm::to_vector<4>(inst->getResults());
|
||||
}
|
||||
|
||||
static Expr createBinaryExpr(
|
||||
Expr lhs, Expr rhs, StringRef intOpName, StringRef floatOpName,
|
||||
std::function<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
|
||||
assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
|
||||
"only single-result exprs are supported in operators");
|
||||
auto thisType = lhs.getResultTypes().front();
|
||||
auto thatType = rhs.getResultTypes().front();
|
||||
assert(thisType == thatType && "cannot mix types in operators");
|
||||
StringRef opName;
|
||||
if (thisType.isIndex()) {
|
||||
MLIRContext *context = thisType.getContext();
|
||||
auto d0 = getAffineDimExpr(0, context);
|
||||
auto d1 = getAffineDimExpr(1, context);
|
||||
auto map = AffineMap::get(2, 0, {affCombiner(d0, d1)}, {});
|
||||
auto attr = AffineMapAttr::get(map);
|
||||
auto attrId = Identifier::get("map", context);
|
||||
auto namedAttr = NamedAttribute{attrId, attr};
|
||||
return VariadicExpr("affine.apply", {lhs, rhs}, {IndexType::get(context)},
|
||||
{namedAttr});
|
||||
} else if (thisType.isa<IntegerType>()) {
|
||||
opName = intOpName;
|
||||
} else if (thisType.isa<FloatType>()) {
|
||||
opName = floatOpName;
|
||||
} else if (auto aggregateType = thisType.dyn_cast<VectorOrTensorType>()) {
|
||||
if (aggregateType.getElementType().isa<IntegerType>())
|
||||
opName = intOpName;
|
||||
else if (aggregateType.getElementType().isa<FloatType>())
|
||||
opName = floatOpName;
|
||||
}
|
||||
if (!opName.empty())
|
||||
return BinaryExpr(opName, thisType, lhs, rhs);
|
||||
|
||||
llvm_unreachable("failed to create an Expr");
|
||||
}
|
||||
|
||||
Expr mlir::edsc::op::operator+(Expr lhs, Expr rhs) {
|
||||
return BinaryExpr(ExprKind::Add, lhs, rhs);
|
||||
return createBinaryExpr(lhs, rhs, "addi", "addf",
|
||||
[](AffineExpr d0, AffineExpr d1) { return d0 + d1; });
|
||||
}
|
||||
Expr mlir::edsc::op::operator-(Expr lhs, Expr rhs) {
|
||||
return BinaryExpr(ExprKind::Sub, lhs, rhs);
|
||||
return createBinaryExpr(lhs, rhs, "subi", "subf",
|
||||
[](AffineExpr d0, AffineExpr d1) { return d0 - d1; });
|
||||
}
|
||||
Expr mlir::edsc::op::operator*(Expr lhs, Expr rhs) {
|
||||
return BinaryExpr(ExprKind::Mul, lhs, rhs);
|
||||
return createBinaryExpr(lhs, rhs, "muli", "mulf",
|
||||
[](AffineExpr d0, AffineExpr d1) { return d0 * d1; });
|
||||
}
|
||||
|
||||
static Expr createComparisonExpr(CmpIPredicate predicate, Expr lhs, Expr rhs) {
|
||||
assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
|
||||
"only single-result exprs are supported in operators");
|
||||
auto lhsType = lhs.getResultTypes().front();
|
||||
auto rhsType = rhs.getResultTypes().front();
|
||||
assert(lhsType == rhsType && "cannot mix types in operators");
|
||||
assert((lhsType.isa<IndexType>() || lhsType.isa<IntegerType>()) &&
|
||||
"only integer comparisons are supported");
|
||||
|
||||
MLIRContext *context = lhsType.getContext();
|
||||
auto attr = IntegerAttr::get(IndexType::get(context),
|
||||
static_cast<int64_t>(predicate));
|
||||
auto attrId = Identifier::get("predicate", context);
|
||||
auto namedAttr = NamedAttribute{attrId, attr};
|
||||
|
||||
return BinaryExpr("cmpi", IntegerType::get(1, context), lhs, rhs,
|
||||
{namedAttr});
|
||||
}
|
||||
|
||||
Expr mlir::edsc::op::operator==(Expr lhs, Expr rhs) {
|
||||
return BinaryExpr(ExprKind::EQ, lhs, rhs);
|
||||
return createComparisonExpr(CmpIPredicate::EQ, lhs, rhs);
|
||||
}
|
||||
Expr mlir::edsc::op::operator!=(Expr lhs, Expr rhs) {
|
||||
return BinaryExpr(ExprKind::NE, lhs, rhs);
|
||||
return createComparisonExpr(CmpIPredicate::NE, lhs, rhs);
|
||||
}
|
||||
Expr mlir::edsc::op::operator<(Expr lhs, Expr rhs) {
|
||||
return BinaryExpr(ExprKind::LT, lhs, rhs);
|
||||
// TODO(ntv,zinenko): signed by default, how about unsigned?
|
||||
return createComparisonExpr(CmpIPredicate::SLT, lhs, rhs);
|
||||
}
|
||||
Expr mlir::edsc::op::operator<=(Expr lhs, Expr rhs) {
|
||||
return BinaryExpr(ExprKind::LE, lhs, rhs);
|
||||
return createComparisonExpr(CmpIPredicate::SLE, lhs, rhs);
|
||||
}
|
||||
Expr mlir::edsc::op::operator>(Expr lhs, Expr rhs) {
|
||||
return BinaryExpr(ExprKind::GT, lhs, rhs);
|
||||
return createComparisonExpr(CmpIPredicate::SGT, lhs, rhs);
|
||||
}
|
||||
Expr mlir::edsc::op::operator>=(Expr lhs, Expr rhs) {
|
||||
return BinaryExpr(ExprKind::GE, lhs, rhs);
|
||||
}
|
||||
Expr mlir::edsc::op::operator&&(Expr lhs, Expr rhs) {
|
||||
return BinaryExpr(ExprKind::And, lhs, rhs);
|
||||
}
|
||||
Expr mlir::edsc::op::operator||(Expr lhs, Expr rhs) {
|
||||
return BinaryExpr(ExprKind::Or, lhs, rhs);
|
||||
}
|
||||
Expr mlir::edsc::op::operator!(Expr expr) {
|
||||
return UnaryExpr(ExprKind::Negate, expr);
|
||||
return createComparisonExpr(CmpIPredicate::SGE, lhs, rhs);
|
||||
}
|
||||
|
||||
llvm::SmallVector<Expr, 8> mlir::edsc::makeNewExprs(unsigned n) {
|
||||
Expr mlir::edsc::op::operator&&(Expr lhs, Expr rhs) {
|
||||
assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
|
||||
"expected single-result exprs");
|
||||
auto thisType = lhs.getResultTypes().front();
|
||||
auto thatType = rhs.getResultTypes().front();
|
||||
assert(thisType.isInteger(1) && thatType.isInteger(1) &&
|
||||
"logical And expects i1");
|
||||
return BinaryExpr("muli", thisType, lhs, rhs);
|
||||
}
|
||||
Expr mlir::edsc::op::operator||(Expr lhs, Expr rhs) {
|
||||
// There is not support for bitwise operations, so we emulate logical 'or'
|
||||
// lhs || rhs
|
||||
// as
|
||||
// !(!lhs && !rhs).
|
||||
using namespace edsc::op;
|
||||
return !(!lhs && !rhs);
|
||||
}
|
||||
Expr mlir::edsc::op::operator!(Expr expr) {
|
||||
assert(expr.getResultTypes().size() == 1 && "expected single-result exprs");
|
||||
auto thisType = expr.getResultTypes().front();
|
||||
assert(thisType.isInteger(1) && "logical Not expects i1");
|
||||
MLIRContext *context = thisType.getContext();
|
||||
|
||||
// Create constant 1 expression.s
|
||||
auto attr = IntegerAttr::get(thisType, 1);
|
||||
auto attrId = Identifier::get("value", context);
|
||||
auto namedAttr = NamedAttribute{attrId, attr};
|
||||
auto cstOne = VariadicExpr("constant", {}, thisType, {namedAttr});
|
||||
|
||||
// Emulate negation as (1 - x) : i1
|
||||
return cstOne - expr;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Expr, 8> mlir::edsc::makeNewExprs(unsigned n, Type type) {
|
||||
llvm::SmallVector<Expr, 8> res;
|
||||
res.reserve(n);
|
||||
for (auto i = 0; i < n; ++i) {
|
||||
res.push_back(Expr());
|
||||
res.push_back(Expr(type));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
@ -183,15 +321,15 @@ static void fillStmts(edsc_stmt_list_t enclosedStmts,
|
|||
}
|
||||
|
||||
Expr mlir::edsc::alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) {
|
||||
return VariadicExpr(ExprKind::Alloc, sizes, memrefType);
|
||||
return VariadicExpr("alloc", sizes, memrefType);
|
||||
}
|
||||
|
||||
Expr mlir::edsc::dealloc(Expr memref) {
|
||||
return UnaryExpr(ExprKind::Dealloc, memref);
|
||||
}
|
||||
Expr mlir::edsc::dealloc(Expr memref) { return UnaryExpr("dealloc", memref); }
|
||||
|
||||
Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) {
|
||||
Expr idx;
|
||||
assert(lb.getResultTypes().size() == 1 && "expected single-result bounds");
|
||||
auto type = lb.getResultTypes().front();
|
||||
Expr idx(type);
|
||||
return For(Bindable(idx), lb, ub, step, stmts);
|
||||
}
|
||||
|
||||
|
@ -249,10 +387,14 @@ edsc_block_t Block(edsc_stmt_list_t enclosedStmts) {
|
|||
}
|
||||
|
||||
Expr mlir::edsc::load(Expr m, ArrayRef<Expr> indices) {
|
||||
assert(m.getResultTypes().size() == 1 && "expected single-result expr");
|
||||
auto type = m.getResultTypes().front().dyn_cast<MemRefType>();
|
||||
assert(type && "expected memref type");
|
||||
|
||||
SmallVector<Expr, 8> exprs;
|
||||
exprs.push_back(m);
|
||||
exprs.append(indices.begin(), indices.end());
|
||||
return VariadicExpr(ExprKind::Load, exprs);
|
||||
return VariadicExpr("load", exprs, {type.getElementType()});
|
||||
}
|
||||
|
||||
edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices) {
|
||||
|
@ -267,7 +409,7 @@ Expr mlir::edsc::store(Expr val, Expr m, ArrayRef<Expr> indices) {
|
|||
exprs.push_back(val);
|
||||
exprs.push_back(m);
|
||||
exprs.append(indices.begin(), indices.end());
|
||||
return VariadicExpr(ExprKind::Store, exprs);
|
||||
return VariadicExpr("store", exprs);
|
||||
}
|
||||
|
||||
edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed,
|
||||
|
@ -279,7 +421,7 @@ edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed,
|
|||
}
|
||||
|
||||
Expr mlir::edsc::select(Expr cond, Expr lhs, Expr rhs) {
|
||||
return TernaryExpr(ExprKind::Select, cond, lhs, rhs);
|
||||
return TernaryExpr("select", cond, lhs, rhs);
|
||||
}
|
||||
|
||||
edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs) {
|
||||
|
@ -287,106 +429,210 @@ edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs) {
|
|||
}
|
||||
|
||||
Expr mlir::edsc::vector_type_cast(Expr memrefExpr, Type memrefType) {
|
||||
return VariadicExpr(ExprKind::VectorTypeCast, {memrefExpr}, {memrefType});
|
||||
return VariadicExpr("vector_type_cast", {memrefExpr}, {memrefType});
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::Return(ArrayRef<Expr> values) {
|
||||
return VariadicExpr(ExprKind::Return, values);
|
||||
return VariadicExpr("return", values);
|
||||
}
|
||||
|
||||
edsc_stmt_t Return(edsc_expr_list_t values) {
|
||||
return Stmt(Return(makeExprs(values)));
|
||||
}
|
||||
|
||||
static raw_ostream &printBinaryExpr(raw_ostream &os, BinaryExpr e,
|
||||
StringRef infix) {
|
||||
os << '(' << e.getLHS() << ' ' << infix << ' ' << e.getRHS() << ')';
|
||||
return os;
|
||||
}
|
||||
|
||||
// Get the operator spelling for pretty-printing the infix form of a
|
||||
// comparison operator.
|
||||
static StringRef getCmpIPredicateInfix(const mlir::edsc::Expr &e) {
|
||||
Attribute predicate;
|
||||
for (const auto &namedAttr : e.getAttributes()) {
|
||||
if (namedAttr.first.is(CmpIOp::getPredicateAttrName())) {
|
||||
predicate = namedAttr.second;
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert(predicate && "expected a predicate in a comparison expr");
|
||||
|
||||
switch (static_cast<CmpIPredicate>(
|
||||
predicate.cast<IntegerAttr>().getValue().getSExtValue())) {
|
||||
case CmpIPredicate::EQ:
|
||||
return "==";
|
||||
case CmpIPredicate::NE:
|
||||
return "!=";
|
||||
case CmpIPredicate::SGT:
|
||||
case CmpIPredicate::UGT:
|
||||
return ">";
|
||||
case CmpIPredicate::SLT:
|
||||
case CmpIPredicate::ULT:
|
||||
return "<";
|
||||
case CmpIPredicate::SGE:
|
||||
case CmpIPredicate::UGE:
|
||||
return ">=";
|
||||
case CmpIPredicate::SLE:
|
||||
case CmpIPredicate::ULE:
|
||||
return "<=";
|
||||
default:
|
||||
llvm_unreachable("unknown predicate");
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
static void printAffineExpr(raw_ostream &os, AffineExpr expr,
|
||||
ArrayRef<Expr> dims, ArrayRef<Expr> symbols) {
|
||||
struct Visitor : public AffineExprVisitor<Visitor> {
|
||||
Visitor(raw_ostream &ostream, ArrayRef<Expr> dimExprs,
|
||||
ArrayRef<Expr> symExprs)
|
||||
: os(ostream), dims(dimExprs), symbols(symExprs) {}
|
||||
raw_ostream &os;
|
||||
ArrayRef<Expr> dims;
|
||||
ArrayRef<Expr> symbols;
|
||||
|
||||
void visitDimExpr(AffineDimExpr dimExpr) {
|
||||
os << dims[dimExpr.getPosition()];
|
||||
}
|
||||
|
||||
void visitSymbolExpr(AffineSymbolExpr symbolExpr) {
|
||||
os << symbols[symbolExpr.getPosition()];
|
||||
}
|
||||
|
||||
void visitConstantExpr(AffineConstantExpr constExpr) {
|
||||
os << constExpr.getValue();
|
||||
}
|
||||
|
||||
void visitBinaryExpr(AffineBinaryOpExpr expr, StringRef infix) {
|
||||
visit(expr.getLHS());
|
||||
os << infix;
|
||||
visit(expr.getRHS());
|
||||
}
|
||||
|
||||
void visitAddExpr(AffineBinaryOpExpr binOp) {
|
||||
visitBinaryExpr(binOp, " + ");
|
||||
}
|
||||
|
||||
void visitMulExpr(AffineBinaryOpExpr binOp) {
|
||||
visitBinaryExpr(binOp, " * ");
|
||||
}
|
||||
|
||||
void visitModExpr(AffineBinaryOpExpr binOp) {
|
||||
visitBinaryExpr(binOp, " % ");
|
||||
}
|
||||
|
||||
void visitCeilDivExpr(AffineBinaryOpExpr binOp) {
|
||||
visitBinaryExpr(binOp, " ceildiv ");
|
||||
}
|
||||
|
||||
void visitFloorDivExpr(AffineBinaryOpExpr binOp) {
|
||||
visitBinaryExpr(binOp, " floordiv ");
|
||||
}
|
||||
};
|
||||
|
||||
Visitor(os, dims, symbols).visit(expr);
|
||||
}
|
||||
|
||||
static void printAffineMap(raw_ostream &os, AffineMap map,
|
||||
ArrayRef<Expr> operands) {
|
||||
auto dims = operands.take_front(map.getNumDims());
|
||||
auto symbols = operands.drop_front(map.getNumDims());
|
||||
assert(map.getNumResults() == 1 &&
|
||||
"only 1-result maps are currently supported");
|
||||
printAffineExpr(os, map.getResult(0), dims, symbols);
|
||||
}
|
||||
|
||||
void printAffineApply(raw_ostream &os, mlir::edsc::Expr e) {
|
||||
Attribute mapAttr;
|
||||
for (const auto &namedAttr : e.getAttributes()) {
|
||||
if (namedAttr.first.is("map")) {
|
||||
mapAttr = namedAttr.second;
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert(mapAttr && "expected a map in an affine apply expr");
|
||||
|
||||
printAffineMap(os, mapAttr.cast<AffineMapAttr>().getValue(),
|
||||
e.getChildExpressions());
|
||||
}
|
||||
|
||||
void mlir::edsc::Expr::print(raw_ostream &os) const {
|
||||
if (auto unbound = this->dyn_cast<Bindable>()) {
|
||||
os << "$" << unbound.getId();
|
||||
return;
|
||||
} else if (auto un = this->dyn_cast<UnaryExpr>()) {
|
||||
switch (un.getKind()) {
|
||||
case ExprKind::Negate:
|
||||
os << "~";
|
||||
break;
|
||||
default: {
|
||||
os << "unknown_unary";
|
||||
}
|
||||
}
|
||||
os << un.getExpr();
|
||||
} else if (auto bin = this->dyn_cast<BinaryExpr>()) {
|
||||
os << "(" << bin.getLHS();
|
||||
switch (bin.getKind()) {
|
||||
case ExprKind::Add:
|
||||
os << " + ";
|
||||
break;
|
||||
case ExprKind::Sub:
|
||||
os << " - ";
|
||||
break;
|
||||
case ExprKind::Mul:
|
||||
os << " * ";
|
||||
break;
|
||||
case ExprKind::Div:
|
||||
os << " / ";
|
||||
break;
|
||||
case ExprKind::LT:
|
||||
os << " < ";
|
||||
break;
|
||||
case ExprKind::LE:
|
||||
os << " <= ";
|
||||
break;
|
||||
case ExprKind::GT:
|
||||
os << " > ";
|
||||
break;
|
||||
case ExprKind::GE:
|
||||
os << " >= ";
|
||||
break;
|
||||
case ExprKind::EQ:
|
||||
os << " == ";
|
||||
break;
|
||||
case ExprKind::NE:
|
||||
os << " != ";
|
||||
break;
|
||||
case ExprKind::And:
|
||||
os << " && ";
|
||||
break;
|
||||
case ExprKind::Or:
|
||||
os << " || ";
|
||||
break;
|
||||
default: {
|
||||
os << "unknown_binary";
|
||||
}
|
||||
}
|
||||
os << bin.getRHS() << ")";
|
||||
|
||||
// Handle known binary ops with pretty infix forms.
|
||||
if (auto binExpr = this->dyn_cast<BinaryExpr>()) {
|
||||
StringRef name = getName();
|
||||
StringRef infix;
|
||||
if (name == AddIOp::getOperationName() ||
|
||||
name == AddFOp::getOperationName())
|
||||
infix = "+";
|
||||
else if (name == SubIOp::getOperationName() ||
|
||||
name == SubFOp::getOperationName())
|
||||
infix = "-";
|
||||
else if (name == MulIOp::getOperationName() ||
|
||||
name == MulFOp::getOperationName())
|
||||
infix = binExpr.getResultTypes().front().isInteger(1) ? "&&" : "*";
|
||||
else if (name == DivIUOp::getOperationName() ||
|
||||
name == DivISOp::getOperationName())
|
||||
infix = "/";
|
||||
else if (name == RemIUOp::getOperationName() ||
|
||||
name == RemISOp::getOperationName())
|
||||
infix = "%";
|
||||
else if (name == CmpIOp::getOperationName())
|
||||
infix = getCmpIPredicateInfix(*this);
|
||||
|
||||
if (!infix.empty()) {
|
||||
printBinaryExpr(os, binExpr, infix);
|
||||
return;
|
||||
} else if (auto ter = this->dyn_cast<TernaryExpr>()) {
|
||||
switch (ter.getKind()) {
|
||||
case ExprKind::Select:
|
||||
os << "select(" << ter.getCond() << ", " << ter.getLHS() << ", "
|
||||
<< ter.getRHS() << ")";
|
||||
return;
|
||||
default: {
|
||||
os << "unknown_ternary";
|
||||
}
|
||||
}
|
||||
} else if (auto nar = this->dyn_cast<VariadicExpr>()) {
|
||||
auto exprs = nar.getExprs();
|
||||
switch (nar.getKind()) {
|
||||
case ExprKind::Load:
|
||||
os << "load(" << exprs[0] << "[";
|
||||
interleaveComma(ArrayRef<Expr>(exprs.begin() + 1, exprs.size() - 1), os);
|
||||
|
||||
// Handle known variadic ops with pretty forms.
|
||||
if (auto narExpr = this->dyn_cast<VariadicExpr>()) {
|
||||
StringRef name = getName();
|
||||
if (name == LoadOp::getOperationName()) {
|
||||
os << name << '(' << getChildExpressions().front() << '[';
|
||||
interleaveComma(getChildExpressions().drop_front(), os);
|
||||
os << "])";
|
||||
return;
|
||||
case ExprKind::Store:
|
||||
os << "store(" << exprs[0] << ", " << exprs[1] << "[";
|
||||
interleaveComma(ArrayRef<Expr>(exprs.begin() + 2, exprs.size() - 2), os);
|
||||
}
|
||||
if (name == StoreOp::getOperationName()) {
|
||||
os << name << '(' << getChildExpressions().front() << ", "
|
||||
<< getChildExpressions()[1] << '[';
|
||||
interleaveComma(getChildExpressions().drop_front(2), os);
|
||||
os << "])";
|
||||
return;
|
||||
case ExprKind::Return:
|
||||
interleaveComma(exprs, os);
|
||||
}
|
||||
if (name == AffineApplyOp::getOperationName()) {
|
||||
os << '(';
|
||||
printAffineApply(os, *this);
|
||||
os << ')';
|
||||
return;
|
||||
default: {
|
||||
os << "unknown_variadic";
|
||||
}
|
||||
}
|
||||
|
||||
// Handle all other types of ops with a more generic printing form.
|
||||
if (this->isa<UnaryExpr>() || this->isa<BinaryExpr>() ||
|
||||
this->isa<TernaryExpr>() || this->isa<VariadicExpr>()) {
|
||||
os << (getName().empty() ? "##unknown##" : getName()) << '(';
|
||||
interleaveComma(getChildExpressions(), os);
|
||||
auto attrs = getAttributes();
|
||||
if (!attrs.empty()) {
|
||||
os << '{';
|
||||
interleave(
|
||||
attrs,
|
||||
[&os](const NamedAttribute &attr) {
|
||||
os << attr.first.strref() << ": " << attr.second;
|
||||
},
|
||||
[&os]() { os << ", "; });
|
||||
os << '}';
|
||||
}
|
||||
os << ')';
|
||||
return;
|
||||
} else if (auto stmtLikeExpr = this->dyn_cast<StmtBlockLikeExpr>()) {
|
||||
auto exprs = stmtLikeExpr.getExprs();
|
||||
switch (stmtLikeExpr.getKind()) {
|
||||
|
@ -419,21 +665,26 @@ llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
|
|||
return os;
|
||||
}
|
||||
|
||||
edsc_expr_t makeBindable() { return Bindable(Expr()); }
|
||||
edsc_expr_t makeBindable(mlir_type_t type) {
|
||||
return Bindable(Expr(Type(reinterpret_cast<const Type::ImplType *>(type))));
|
||||
}
|
||||
|
||||
mlir::edsc::UnaryExpr::UnaryExpr(ExprKind kind, Expr expr)
|
||||
mlir::edsc::UnaryExpr::UnaryExpr(StringRef name, Expr expr)
|
||||
: Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
|
||||
// Initialize with placement new.
|
||||
new (storage) detail::ExprStorage(kind, {}, {expr}, {});
|
||||
new (storage)
|
||||
detail::ExprStorage(ExprKind::FIRST_UNARY_EXPR, name, {}, {expr}, {});
|
||||
}
|
||||
Expr mlir::edsc::UnaryExpr::getExpr() const {
|
||||
return static_cast<ImplType *>(storage)->operands.front();
|
||||
}
|
||||
|
||||
mlir::edsc::BinaryExpr::BinaryExpr(ExprKind kind, Expr lhs, Expr rhs)
|
||||
mlir::edsc::BinaryExpr::BinaryExpr(StringRef name, Type result, Expr lhs,
|
||||
Expr rhs, ArrayRef<NamedAttribute> attrs)
|
||||
: Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
|
||||
// Initialize with placement new.
|
||||
new (storage) detail::ExprStorage(kind, {}, {lhs, rhs}, {});
|
||||
new (storage) detail::ExprStorage(ExprKind::FIRST_BINARY_EXPR, name, {result},
|
||||
{lhs, rhs}, attrs);
|
||||
}
|
||||
Expr mlir::edsc::BinaryExpr::getLHS() const {
|
||||
return static_cast<ImplType *>(storage)->operands.front();
|
||||
|
@ -442,11 +693,15 @@ Expr mlir::edsc::BinaryExpr::getRHS() const {
|
|||
return static_cast<ImplType *>(storage)->operands.back();
|
||||
}
|
||||
|
||||
mlir::edsc::TernaryExpr::TernaryExpr(ExprKind kind, Expr cond, Expr lhs,
|
||||
mlir::edsc::TernaryExpr::TernaryExpr(StringRef name, Expr cond, Expr lhs,
|
||||
Expr rhs)
|
||||
: Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
|
||||
// Initialize with placement new.
|
||||
new (storage) detail::ExprStorage(kind, {}, {cond, lhs, rhs}, {});
|
||||
assert(lhs.getResultTypes().size() == 1 && "expected single-result expr");
|
||||
assert(rhs.getResultTypes().size() == 1 && "expected single-result expr");
|
||||
new (storage)
|
||||
detail::ExprStorage(ExprKind::FIRST_TERNARY_EXPR, name,
|
||||
{lhs.getResultTypes().front()}, {cond, lhs, rhs}, {});
|
||||
}
|
||||
Expr mlir::edsc::TernaryExpr::getCond() const {
|
||||
return static_cast<ImplType *>(storage)->operands[0];
|
||||
|
@ -458,11 +713,13 @@ Expr mlir::edsc::TernaryExpr::getRHS() const {
|
|||
return static_cast<ImplType *>(storage)->operands[2];
|
||||
}
|
||||
|
||||
mlir::edsc::VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef<Expr> exprs,
|
||||
ArrayRef<Type> types)
|
||||
mlir::edsc::VariadicExpr::VariadicExpr(StringRef name, ArrayRef<Expr> exprs,
|
||||
ArrayRef<Type> types,
|
||||
ArrayRef<NamedAttribute> attrs)
|
||||
: Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
|
||||
// Initialize with placement new.
|
||||
new (storage) detail::ExprStorage(kind, types, exprs, {});
|
||||
new (storage) detail::ExprStorage(ExprKind::FIRST_VARIADIC_EXPR, name, types,
|
||||
exprs, attrs);
|
||||
}
|
||||
ArrayRef<Expr> mlir::edsc::VariadicExpr::getExprs() const {
|
||||
return static_cast<ImplType *>(storage)->operands;
|
||||
|
@ -476,7 +733,7 @@ mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind,
|
|||
ArrayRef<Type> types)
|
||||
: Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
|
||||
// Initialize with placement new.
|
||||
new (storage) detail::ExprStorage(kind, types, exprs, {});
|
||||
new (storage) detail::ExprStorage(kind, "", types, exprs, {});
|
||||
}
|
||||
ArrayRef<Expr> mlir::edsc::StmtBlockLikeExpr::getExprs() const {
|
||||
return static_cast<ImplType *>(storage)->operands;
|
||||
|
@ -494,8 +751,9 @@ mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
|
|||
lhs, rhs, ArrayRef<Stmt>(enclosedStmtStorage, enclosedStmts.size())};
|
||||
}
|
||||
|
||||
// Statement with enclosed statements does not have a LHS.
|
||||
mlir::edsc::Stmt::Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> enclosedStmts)
|
||||
: Stmt(Bindable(Expr()), rhs, enclosedStmts) {}
|
||||
: Stmt(Bindable(Expr(Type())), rhs, enclosedStmts) {}
|
||||
|
||||
edsc_stmt_t makeStmt(edsc_expr_t e) {
|
||||
assert(e && "unexpected empty expression");
|
||||
|
@ -503,7 +761,7 @@ edsc_stmt_t makeStmt(edsc_expr_t e) {
|
|||
}
|
||||
|
||||
Stmt &mlir::edsc::Stmt::operator=(const Expr &expr) {
|
||||
Stmt res(Bindable(Expr()), expr, {});
|
||||
Stmt res(Bindable(Expr(Type())), expr, {});
|
||||
std::swap(res.storage, this->storage);
|
||||
return *this;
|
||||
}
|
||||
|
@ -678,6 +936,12 @@ mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
|
|||
return mlir_type_t{ft.getAsOpaquePointer()};
|
||||
}
|
||||
|
||||
mlir_type_t makeIndexType(mlir_context_t context) {
|
||||
auto *ctx = reinterpret_cast<mlir::MLIRContext *>(context);
|
||||
auto type = mlir::IndexType::get(ctx);
|
||||
return mlir_type_t{type.getAsOpaquePointer()};
|
||||
}
|
||||
|
||||
unsigned getFunctionArity(mlir_func_t function) {
|
||||
auto *f = reinterpret_cast<mlir::Function *>(function);
|
||||
return f->getNumArguments();
|
||||
|
|
|
@ -167,7 +167,7 @@ VectorTransferRewriter<VectorTransferOpTy>::makeVectorTransferAccessInfo() {
|
|||
|
||||
// Create new Exprs for ivs, they will be bound at `For` Stmt
|
||||
// construction.
|
||||
auto ivs = makeNewExprs(vectorShape.size());
|
||||
auto ivs = makeNewExprs(vectorShape.size(), this->rewriter->getIndexType());
|
||||
|
||||
// Create and bind Exprs to refer to the Value for memref sizes.
|
||||
auto memRefSizes = emitter.makeBoundMemRefShape(transfer->getMemRef());
|
||||
|
@ -222,9 +222,9 @@ VectorTransferRewriter<VectorTransferOpTy>::makeVectorTransferAccessInfo() {
|
|||
|
||||
// Create the proper bindables for lbs, ubs and steps. Additionally, if we
|
||||
// recorded a coalescing index, permute the loop informations.
|
||||
auto lbs = makeNewExprs(ivs.size());
|
||||
auto lbs = makeNewExprs(ivs.size(), this->rewriter->getIndexType());
|
||||
auto ubs = copyExprs(vectorSizes);
|
||||
auto steps = makeNewExprs(ivs.size());
|
||||
auto steps = makeNewExprs(ivs.size(), this->rewriter->getIndexType());
|
||||
if (coalescingIndex >= 0) {
|
||||
std::swap(ivs[coalescingIndex], ivs.back());
|
||||
std::swap(lbs[coalescingIndex], lbs.back());
|
||||
|
@ -257,11 +257,14 @@ VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
|
|||
MemRefType::get(vectorShape, vectorType.getElementType(), {}, 0)),
|
||||
vectorMemRefType(MemRefType::get({1}, vectorType, {}, 0)),
|
||||
emitter(edsc::MLIREmitter(rewriter->getBuilder(), transfer->getLoc())),
|
||||
vectorSizes(edsc::makeNewExprs(vectorShape.size())), zero(emitter.zero()),
|
||||
one(emitter.one()) {
|
||||
vectorSizes(
|
||||
edsc::makeNewExprs(vectorShape.size(), rewriter->getIndexType())),
|
||||
zero(emitter.zero()), one(emitter.one()),
|
||||
scalarMemRef(transfer->getMemRefType()) {
|
||||
// Bind the Bindable.
|
||||
SmallVector<Value *, 8> transferIndices(transfer->getIndices());
|
||||
accesses = edsc::makeNewExprs(transferIndices.size());
|
||||
accesses = edsc::makeNewExprs(transferIndices.size(),
|
||||
this->rewriter->getIndexType());
|
||||
emitter.bind(edsc::Bindable(scalarMemRef), transfer->getMemRef())
|
||||
.template bindZipRangeConstants<ConstantIndexOp>(
|
||||
llvm::zip(vectorSizes, vectorShape))
|
||||
|
@ -321,7 +324,11 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
|
|||
auto &lbs = accessInfo.lowerBoundsExprs;
|
||||
auto &ubs = accessInfo.upperBoundsExprs;
|
||||
auto &steps = accessInfo.stepExprs;
|
||||
Expr scalarValue, vectorValue, tmpAlloc, tmpDealloc, vectorView;
|
||||
|
||||
auto vectorType = this->transfer->getVectorType();
|
||||
auto scalarType = this->transfer->getMemRefType().getElementType();
|
||||
|
||||
Expr scalarValue(scalarType), vectorValue(vectorType), tmpAlloc(tmpMemRefType), tmpDealloc(Type{}), vectorView(vectorMemRefType);
|
||||
auto block = edsc::block({
|
||||
tmpAlloc = alloc(tmpMemRefType),
|
||||
vectorView = vector_type_cast(Expr(tmpAlloc), vectorMemRefType),
|
||||
|
@ -368,7 +375,7 @@ template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
|
|||
auto accessInfo = makeVectorTransferAccessInfo();
|
||||
|
||||
// Bind vector value for the vector_transfer_write.
|
||||
Expr vectorValue;
|
||||
Expr vectorValue(transfer->getVectorType());
|
||||
emitter.bind(Bindable(vectorValue), transfer->getVector());
|
||||
|
||||
// clang-format off
|
||||
|
@ -376,7 +383,8 @@ template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
|
|||
auto &lbs = accessInfo.lowerBoundsExprs;
|
||||
auto &ubs = accessInfo.upperBoundsExprs;
|
||||
auto &steps = accessInfo.stepExprs;
|
||||
Expr scalarValue, tmpAlloc, tmpDealloc, vectorView;
|
||||
auto scalarType = tmpMemRefType.getElementType();
|
||||
Expr scalarValue(scalarType), tmpAlloc(tmpMemRefType), tmpDealloc(Type{}), vectorView(vectorMemRefType);
|
||||
auto block = edsc::block({
|
||||
tmpAlloc = alloc(tmpMemRefType),
|
||||
vectorView = vector_type_cast(tmpAlloc, vectorMemRefType),
|
||||
|
|
|
@ -11,7 +11,7 @@ def X_AddOp : Op<"x.add">,
|
|||
// TODO: extract referenceImplementation to Op.
|
||||
// TODO: shrink the reference implementation
|
||||
code referenceImplementation = [{
|
||||
auto ivs = makeNewExprs(view_A.rank());
|
||||
auto ivs = makeNewExprs(view_A.rank(), builder.getIndexType());
|
||||
// TODO(jpienaar@): automate the positional/named extraction. Need to be a
|
||||
// bit careful about things memref (from which a "view" can be extracted)
|
||||
// and the rest (see ReferenceImplGen.cpp).
|
||||
|
|
Loading…
Reference in New Issue