EDSC: make Expr typed and extensible

Expose the result types of edsc::Expr, which are now stored for all types of
Exprs and not only for the variadic ones.  Require return types when an Expr is
constructed, if it will ever have some.  An empty return type list is
interpreted as an Expr that does not create a value (e.g. `return` or `store`).

Conceptually, all edss::Exprs are now typed, with the type being a (potentially
empty) tuple of return types.  Unbound expressions and Bindables must now be
constructed with a specific type they will take.  This makes EDSC less
evidently type-polymorphic, but we can still write generic code such as

    Expr sumOfSquares(Expr lhs, Expr rhs) { return lhs * lhs + rhs * rhs; }

and use it to construct different typed expressions as

    sumOfSquares(Bindable(IndexType::get(ctx)), Bindable(IndexType::get(ctx)));
    sumOfSquares(Bindable(FloatType::getF32(ctx)),
                 Bindable(FloatType::getF32(ctx)));

On the positive side, we get the following.
1. We can now perform type checking when constructing Exprs rather than during
   MLIR emission.  Nevertheless, this is still duplicates the Op::verify()
   until we can factor out type checking from that.
2. MLIREmitter is significantly simplified.
3. ExprKind enum is only used for actual kinds of expressions.  Data structures
   are converging with AbstractOperation, and the users can now create a
   VariadicExpr("canonical_op_name", {types}, {exprs}) for any operation, even
   an unregistered one without having to extend the enum and make pervasive
   changes to EDSCs.

On the negative side, we get the following.
1. Typed bindables are more verbose, even in Python.
2. We lose the ability to do print debugging for higher-level EDSC abstractions
   that are implemented as multiple MLIR Ops, for example logical disjunction.

This is the step 2/n towards making EDSC extensible.

***

Move MLIR Op construction from MLIREmitter::emitExpr to Expr::build since Expr
now has sufficient information to build itself.

This is the step 3/n towards making EDSC extensible.

Both of these strive to minimize the amount of irrelevant changes.  In
particular, this introduces more complex pretty-printing for affine and binary
expression to make sure tests continue to pass.  It also relies on string
comparison to identify specific operations that an Expr produces.

PiperOrigin-RevId: 234609882
This commit is contained in:
Alex Zinenko 2019-02-19 08:51:52 -08:00 committed by jpienaar
parent e0fc503896
commit b4dba895a6
10 changed files with 548 additions and 368 deletions

View File

@ -110,6 +110,9 @@ struct PythonMLIRModule {
return ::makeMemRefType(mlir_context_t{&mlirContext}, elemType, return ::makeMemRefType(mlir_context_t{&mlirContext}, elemType,
int64_list_t{sizes.data(), sizes.size()}); int64_list_t{sizes.data(), sizes.size()});
} }
PythonType makeIndexType() {
return ::makeIndexType(mlir_context_t{&mlirContext});
}
PythonFunction makeFunction(const std::string &name, PythonFunction makeFunction(const std::string &name,
std::vector<PythonType> &inputTypes, std::vector<PythonType> &inputTypes,
std::vector<PythonType> &outputTypes) { std::vector<PythonType> &outputTypes) {
@ -177,7 +180,8 @@ struct PythonExpr {
}; };
struct PythonBindable : public PythonExpr { struct PythonBindable : public PythonExpr {
PythonBindable() : PythonExpr(edsc_expr_t{makeBindable()}) {} explicit PythonBindable(const PythonType &type)
: PythonExpr(edsc_expr_t{makeBindable(type.type)}) {}
PythonBindable(PythonExpr expr) : PythonExpr(expr) { PythonBindable(PythonExpr expr) : PythonExpr(expr) {
assert(Expr(expr).isa<Bindable>() && "Expected Bindable"); assert(Expr(expr).isa<Bindable>() && "Expected Bindable");
} }
@ -213,7 +217,6 @@ struct PythonBlock {
}; };
struct PythonIndexed : public edsc_indexed_t { struct PythonIndexed : public edsc_indexed_t {
PythonIndexed() : edsc_indexed_t{makeIndexed(PythonBindable())} {}
PythonIndexed(PythonExpr e) : edsc_indexed_t{makeIndexed(e)} {} PythonIndexed(PythonExpr e) : edsc_indexed_t{makeIndexed(e)} {}
PythonIndexed(PythonBindable b) : edsc_indexed_t{makeIndexed(b)} {} PythonIndexed(PythonBindable b) : edsc_indexed_t{makeIndexed(b)} {}
operator PythonExpr() { return PythonExpr(base); } operator PythonExpr() { return PythonExpr(base); }
@ -475,6 +478,8 @@ PYBIND11_MODULE(pybind, m) {
.def("make_memref_type", &PythonMLIRModule::makeMemRefType, .def("make_memref_type", &PythonMLIRModule::makeMemRefType,
"Returns an mlir::MemRefType of an elemental scalar. -1 is used to " "Returns an mlir::MemRefType of an elemental scalar. -1 is used to "
"denote symbolic dimensions in the resulting memref shape.") "denote symbolic dimensions in the resulting memref shape.")
.def("make_index_type", &PythonMLIRModule::makeIndexType,
"Returns an mlir::IndexType")
.def("compile", &PythonMLIRModule::compile, .def("compile", &PythonMLIRModule::compile,
"Compiles the mlir::Module to LLVMIR a creates new opaque " "Compiles the mlir::Module to LLVMIR a creates new opaque "
"ExecutionEngine backed by the ORC JIT.") "ExecutionEngine backed by the ORC JIT.")
@ -576,7 +581,7 @@ PYBIND11_MODULE(pybind, m) {
m, "Bindable", m, "Bindable",
"Wrapping class for mlir::edsc::Bindable.\nA Bindable is a special Expr " "Wrapping class for mlir::edsc::Bindable.\nA Bindable is a special Expr "
"that can be bound manually to specific MLIR SSA Values.") "that can be bound manually to specific MLIR SSA Values.")
.def(py::init<>()) .def(py::init<PythonType>())
.def("__str__", &PythonBindable::str); .def("__str__", &PythonBindable::str);
py::class_<PythonStmt>(m, "Stmt", "Wrapping class for mlir::edsc::Stmt.") py::class_<PythonStmt>(m, "Stmt", "Wrapping class for mlir::edsc::Stmt.")
@ -588,7 +593,6 @@ PYBIND11_MODULE(pybind, m) {
m, "Indexed", m, "Indexed",
"Wrapping class for mlir::edsc::Indexed.\nAn Indexed is a wrapper class " "Wrapping class for mlir::edsc::Indexed.\nAn Indexed is a wrapper class "
"that support load and store operations.") "that support load and store operations.")
.def(py::init<>(), R"DOC(Build from fresh Bindable)DOC")
.def(py::init<PythonExpr>(), R"DOC(Build from existing Expr)DOC") .def(py::init<PythonExpr>(), R"DOC(Build from existing Expr)DOC")
.def(py::init<PythonBindable>(), R"DOC(Build from existing Bindable)DOC") .def(py::init<PythonBindable>(), R"DOC(Build from existing Bindable)DOC")
.def( .def(

View File

@ -10,21 +10,30 @@ import google_mlir.bindings.python.pybind as E
class EdscTest(unittest.TestCase): class EdscTest(unittest.TestCase):
def setUp(self):
self.module = E.MLIRModule()
self.boolType = self.module.make_scalar_type("i", 1)
self.i32Type = self.module.make_scalar_type("i", 32)
self.f32Type = self.module.make_scalar_type("f32")
self.indexType = self.module.make_index_type()
def testBindables(self): def testBindables(self):
with E.ContextManager(): with E.ContextManager():
i = E.Expr(E.Bindable()) i = E.Expr(E.Bindable(self.i32Type))
self.assertIn("$1", i.__str__()) self.assertIn("$1", i.__str__())
def testOneExpr(self): def testOneExpr(self):
with E.ContextManager(): with E.ContextManager():
i, lb, ub = list(map(E.Expr, [E.Bindable() for _ in range(3)])) i, lb, ub = list(
map(E.Expr, [E.Bindable(self.i32Type) for _ in range(3)]))
expr = E.Mul(i, E.Add(lb, ub)) expr = E.Mul(i, E.Add(lb, ub))
str = expr.__str__() str = expr.__str__()
self.assertIn("($1 * ($2 + $3))", str) self.assertIn("($1 * ($2 + $3))", str)
def testOneLoop(self): def testOneLoop(self):
with E.ContextManager(): with E.ContextManager():
i, lb, ub, step = list(map(E.Expr, [E.Bindable() for _ in range(4)])) i, lb, ub, step = list(
map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
loop = E.For(i, lb, ub, step, [E.Stmt(E.Add(lb, ub))]) loop = E.For(i, lb, ub, step, [E.Stmt(E.Add(lb, ub))])
str = loop.__str__() str = loop.__str__()
self.assertIn("for($1 = $2 to $3 step $4) {", str) self.assertIn("for($1 = $2 to $3 step $4) {", str)
@ -32,7 +41,8 @@ class EdscTest(unittest.TestCase):
def testTwoLoops(self): def testTwoLoops(self):
with E.ContextManager(): with E.ContextManager():
i, lb, ub, step = list(map(E.Expr, [E.Bindable() for _ in range(4)])) i, lb, ub, step = list(
map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
loop = E.For(i, lb, ub, step, [E.For(i, lb, ub, step, [E.Stmt(i)])]) loop = E.For(i, lb, ub, step, [E.For(i, lb, ub, step, [E.Stmt(i)])])
str = loop.__str__() str = loop.__str__()
self.assertIn("for($1 = $2 to $3 step $4) {", str) self.assertIn("for($1 = $2 to $3 step $4) {", str)
@ -41,11 +51,12 @@ class EdscTest(unittest.TestCase):
def testNestedLoops(self): def testNestedLoops(self):
with E.ContextManager(): with E.ContextManager():
i, lb, ub, step = list(map(E.Expr, [E.Bindable() for _ in range(4)])) i, lb, ub, step = list(
ivs = list(map(E.Expr, [E.Bindable() for _ in range(4)])) map(E.Expr, [E.Bindable(self.i32Type) for _ in range(4)]))
lbs = list(map(E.Expr, [E.Bindable() for _ in range(4)])) ivs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
ubs = list(map(E.Expr, [E.Bindable() for _ in range(4)])) lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
steps = list(map(E.Expr, [E.Bindable() for _ in range(4)])) ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
steps = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
loop = E.For(ivs, lbs, ubs, steps, [ loop = E.For(ivs, lbs, ubs, steps, [
E.For(i, lb, ub, step, [E.Stmt(ub * step - lb)]), E.For(i, lb, ub, step, [E.Stmt(ub * step - lb)]),
]) ])
@ -59,20 +70,23 @@ class EdscTest(unittest.TestCase):
def testIndexed(self): def testIndexed(self):
with E.ContextManager(): with E.ContextManager():
i, j, k = list(map(E.Expr, [E.Bindable() for _ in range(3)])) i, j, k = list(
A, B, C = list(map(E.Indexed, [E.Bindable() for _ in range(3)])) map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
memrefType = self.module.make_memref_type(self.f32Type, [42, 42])
A, B, C = list(map(E.Indexed, [E.Bindable(memrefType) for _ in range(3)]))
stmt = C.store([i, j], A.load([i, k]) * B.load([k, j])) stmt = C.store([i, j], A.load([i, k]) * B.load([k, j]))
str = stmt.__str__() str = stmt.__str__()
self.assertIn(" = store(", str) self.assertIn(" = store(", str)
def testMatmul(self): def testMatmul(self):
with E.ContextManager(): with E.ContextManager():
ivs = list(map(E.Expr, [E.Bindable() for _ in range(3)])) ivs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
lbs = list(map(E.Expr, [E.Bindable() for _ in range(3)])) lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
ubs = list(map(E.Expr, [E.Bindable() for _ in range(3)])) ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
steps = list(map(E.Expr, [E.Bindable() for _ in range(3)])) steps = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
i, j, k = ivs[0], ivs[1], ivs[2] i, j, k = ivs[0], ivs[1], ivs[2]
A, B, C = list(map(E.Indexed, [E.Bindable() for _ in range(3)])) memrefType = self.module.make_memref_type(self.f32Type, [42, 42])
A, B, C = list(map(E.Indexed, [E.Bindable(memrefType) for _ in range(3)]))
loop = E.For( loop = E.For(
ivs, lbs, ubs, steps, ivs, lbs, ubs, steps,
[C.store([i, j], [C.store([i, j],
@ -85,29 +99,36 @@ class EdscTest(unittest.TestCase):
def testArithmetic(self): def testArithmetic(self):
with E.ContextManager(): with E.ContextManager():
i, j, k, l = list(map(E.Expr, [E.Bindable() for _ in range(4)])) i, j, k, l = list(
map(E.Expr, [E.Bindable(self.f32Type) for _ in range(4)]))
stmt = i + j * k - l stmt = i + j * k - l
str = stmt.__str__() str = stmt.__str__()
self.assertIn("(($1 + ($2 * $3)) - $4)", str) self.assertIn("(($1 + ($2 * $3)) - $4)", str)
def testBoolean(self): def testBoolean(self):
with E.ContextManager(): with E.ContextManager():
i, j, k, l = list(map(E.Expr, [E.Bindable() for _ in range(4)])) i, j, k, l = list(
map(E.Expr, [E.Bindable(self.i32Type) for _ in range(4)]))
stmt1 = (i < j) & (j >= k) stmt1 = (i < j) & (j >= k)
stmt2 = ~(stmt1 | (k == l)) stmt2 = ~(stmt1 | (k == l))
str = stmt2.__str__() str = stmt2.__str__()
self.assertIn("~((($1 < $2) && ($2 >= $3)) || ($3 == $4))", str) # Note that "a | b" is currently implemented as ~(~a && ~b) and "~a" is
# currently implemented as "constant 1 - a", which leads to this
# expression.
self.assertIn(
"(constant({value: 1}) - (constant({value: 1}) - ((constant({value: 1}) - (($1 < $2) && ($2 >= $3))) && (constant({value: 1}) - ($3 == $4)))))",
str)
def testSelect(self): def testSelect(self):
with E.ContextManager(): with E.ContextManager():
i, j, k = list(map(E.Expr, [E.Bindable() for _ in range(3)])) i, j, k = list(map(E.Expr, [E.Bindable(self.i32Type) for _ in range(3)]))
stmt = E.Select(i > j, i, j) stmt = E.Select(i > j, i, j)
str = stmt.__str__() str = stmt.__str__()
self.assertIn("select(($1 > $2), $1, $2)", str) self.assertIn("select(($1 > $2), $1, $2)", str)
def testBlock(self): def testBlock(self):
with E.ContextManager(): with E.ContextManager():
i, j = list(map(E.Expr, [E.Bindable() for _ in range(2)])) i, j = list(map(E.Expr, [E.Bindable(self.f32Type) for _ in range(2)]))
stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)]) stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)])
str = stmt.__str__() str = stmt.__str__()
self.assertIn("^bb:", str) self.assertIn("^bb:", str)
@ -175,16 +196,14 @@ class EdscTest(unittest.TestCase):
self.assertIn("constant 123 : index", str) self.assertIn("constant 123 : index", str)
def testMLIRBooleanEmission(self): def testMLIRBooleanEmission(self):
module = E.MLIRModule() m = self.module.make_memref_type(self.boolType, [10]) # i1 tensor
t = module.make_scalar_type("i", 1) f = self.module.make_function("mkbooltensor", [m, m], [])
m = module.make_memref_type(t, [10]) # i1 tensor
f = module.make_function("mkbooltensor", [m, m], [])
with E.ContextManager(): with E.ContextManager():
emitter = E.MLIRFunctionEmitter(f) emitter = E.MLIRFunctionEmitter(f)
input, output = list(map(E.Indexed, emitter.bind_function_arguments())) input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
i = E.Expr(E.Bindable()) i = E.Expr(E.Bindable(self.indexType))
j = E.Expr(E.Bindable()) j = E.Expr(E.Bindable(self.indexType))
k = E.Expr(E.Bindable()) k = E.Expr(E.Bindable(self.indexType))
idxs = [i, j, k] idxs = [i, j, k]
zero = emitter.bind_constant_index(0) zero = emitter.bind_constant_index(0)
one = emitter.bind_constant_index(1) one = emitter.bind_constant_index(1)
@ -201,17 +220,13 @@ class EdscTest(unittest.TestCase):
emitter.emit_inplace(loop) emitter.emit_inplace(loop)
# str = f.__str__() # str = f.__str__()
# print(str) # print(str)
module.compile() self.module.compile()
self.assertNotEqual(module.get_engine_address(), 0) self.assertNotEqual(self.module.get_engine_address(), 0)
# TODO(ntv): support symbolic For bounds with EDSCs
def testMLIREmission(self): def testMLIREmission(self):
shape = [3, 4, 5] shape = [3, 4, 5]
module = E.MLIRModule() m = self.module.make_memref_type(self.f32Type, shape)
index = module.make_scalar_type("index") f = self.module.make_function("copy", [m, m], [])
t = module.make_scalar_type("f32")
m = module.make_memref_type(t, shape)
f = module.make_function("copy", [m, m], [])
with E.ContextManager(): with E.ContextManager():
emitter = E.MLIRFunctionEmitter(f) emitter = E.MLIRFunctionEmitter(f)
@ -220,7 +235,8 @@ class EdscTest(unittest.TestCase):
input, output = list(map(E.Indexed, emitter.bind_function_arguments())) input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
M, N, O = emitter.bind_indexed_shape(input) M, N, O = emitter.bind_indexed_shape(input)
ivs = list(map(E.Expr, [E.Bindable() for _ in range(len(shape))])) ivs = list(
map(E.Expr, [E.Bindable(self.indexType) for _ in range(len(shape))]))
lbs = [zero, zero, zero] lbs = [zero, zero, zero]
ubs = [M, N, O] ubs = [M, N, O]
steps = [one, one, one] steps = [one, one, one]
@ -237,8 +253,9 @@ class EdscTest(unittest.TestCase):
self.assertIn("""store %0, %arg1[%i0, %i1, %i2] : memref<3x4x5xf32>""", self.assertIn("""store %0, %arg1[%i0, %i1, %i2] : memref<3x4x5xf32>""",
str) str)
module.compile() self.module.compile()
self.assertNotEqual(module.get_engine_address(), 0) self.assertNotEqual(self.module.get_engine_address(), 0)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -24,7 +24,10 @@ class EdscTest(unittest.TestCase):
emitter = E.MLIRFunctionEmitter(f) emitter = E.MLIRFunctionEmitter(f)
input, output = list(map(E.Indexed, emitter.bind_function_arguments())) input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
lbs, ubs, steps = emitter.bind_indexed_view(input) lbs, ubs, steps = emitter.bind_indexed_view(input)
i, *ivs, j = list(map(E.Expr, [E.Bindable() for _ in range(len(shape))])) i, *ivs, j = list(
map(E.Expr,
[E.Bindable(module.make_index_type()) for _ in range(len(shape))
]))
# n-D type and rank agnostic copy-transpose-first-last (where n >= 2). # n-D type and rank agnostic copy-transpose-first-last (where n >= 2).
loop = E.Block([ loop = E.Block([

View File

@ -106,6 +106,9 @@ mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType,
mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs, mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
mlir_type_list_t outputs); mlir_type_list_t outputs);
/// Returns an `mlir::IndexType`.
mlir_type_t makeIndexType(mlir_context_t context);
/// Returns the arity of `function`. /// Returns the arity of `function`.
unsigned getFunctionArity(mlir_func_t function); unsigned getFunctionArity(mlir_func_t function);
@ -189,7 +192,7 @@ void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
edsc_expr_list_t *resultSteps); edsc_expr_list_t *resultSteps);
/// Returns an opaque expression for an mlir::edsc::Expr. /// Returns an opaque expression for an mlir::edsc::Expr.
edsc_expr_t makeBindable(); edsc_expr_t makeBindable(mlir_type_t type);
/// Returns an opaque expression for an mlir::edsc::Stmt. /// Returns an opaque expression for an mlir::edsc::Stmt.
edsc_stmt_t makeStmt(edsc_expr_t e); edsc_stmt_t makeStmt(edsc_expr_t e);

View File

@ -25,6 +25,7 @@
#define MLIR_LIB_EDSC_TYPES_H_ #define MLIR_LIB_EDSC_TYPES_H_
#include "mlir-c/Core.h" #include "mlir-c/Core.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
@ -36,6 +37,7 @@
namespace mlir { namespace mlir {
class MLIRContext; class MLIRContext;
class FuncBuilder;
namespace edsc { namespace edsc {
namespace detail { namespace detail {
@ -72,40 +74,14 @@ enum class ExprKind {
LAST_BINDABLE_EXPR = Unbound, LAST_BINDABLE_EXPR = Unbound,
FIRST_NON_BINDABLE_EXPR = 200, FIRST_NON_BINDABLE_EXPR = 200,
FIRST_UNARY_EXPR = FIRST_NON_BINDABLE_EXPR, FIRST_UNARY_EXPR = FIRST_NON_BINDABLE_EXPR,
Dealloc = FIRST_UNARY_EXPR, LAST_UNARY_EXPR = FIRST_UNARY_EXPR,
Negate,
LAST_UNARY_EXPR = Negate,
FIRST_BINARY_EXPR = 300, FIRST_BINARY_EXPR = 300,
Add = FIRST_BINARY_EXPR, LAST_BINARY_EXPR = FIRST_BINARY_EXPR,
Sub,
Mul,
Div,
AddEQ,
SubEQ,
MulEQ,
DivEQ,
GE,
GT,
LE,
LT,
EQ,
NE,
And,
Or,
LAST_BINARY_EXPR = Or,
FIRST_TERNARY_EXPR = 400, FIRST_TERNARY_EXPR = 400,
Select = FIRST_TERNARY_EXPR,
IfThenElse, IfThenElse,
LAST_TERNARY_EXPR = IfThenElse, LAST_TERNARY_EXPR = IfThenElse,
FIRST_VARIADIC_EXPR = 500, FIRST_VARIADIC_EXPR = 500,
Alloc = FIRST_VARIADIC_EXPR, // Variadic because takes multiple dynamic shape LAST_VARIADIC_EXPR = FIRST_VARIADIC_EXPR,
// values.
Load,
Store,
VectorTypeCast, // Variadic because takes a type and anything taking a type
// is variadic for now.
Return,
LAST_VARIADIC_EXPR = Return,
FIRST_STMT_BLOCK_LIKE_EXPR = 600, FIRST_STMT_BLOCK_LIKE_EXPR = 600,
For = FIRST_STMT_BLOCK_LIKE_EXPR, For = FIRST_STMT_BLOCK_LIKE_EXPR,
LAST_STMT_BLOCK_LIKE_EXPR = For, LAST_STMT_BLOCK_LIKE_EXPR = For,
@ -164,7 +140,7 @@ public:
return allocator; return allocator;
} }
Expr(); explicit Expr(Type type);
/* implicit */ Expr(ImplType *storage) : storage(storage) {} /* implicit */ Expr(ImplType *storage) : storage(storage) {}
explicit Expr(edsc_expr_t expr) explicit Expr(edsc_expr_t expr)
: storage(reinterpret_cast<ImplType *>(expr)) {} : storage(reinterpret_cast<ImplType *>(expr)) {}
@ -180,6 +156,20 @@ public:
/// Returns the classification for this type. /// Returns the classification for this type.
ExprKind getKind() const; ExprKind getKind() const;
unsigned getId() const; unsigned getId() const;
StringRef getName() const;
/// Returns the types of the values this expression produces.
ArrayRef<Type> getResultTypes() const;
/// Returns the list of expressions used as arguments of this expression.
ArrayRef<Expr> getChildExpressions() const;
/// Returns the list of attributes of this expression.
ArrayRef<NamedAttribute> getAttributes() const;
/// Build the IR corresponding to this expression.
SmallVector<Value *, 4>
build(FuncBuilder &b, const llvm::DenseMap<Expr, Value *> &ssaBindings) const;
void print(raw_ostream &os) const; void print(raw_ostream &os) const;
void dump() const; void dump() const;
@ -216,7 +206,7 @@ private:
struct UnaryExpr : public Expr { struct UnaryExpr : public Expr {
friend class Expr; friend class Expr;
UnaryExpr(ExprKind kind, Expr expr); UnaryExpr(StringRef name, Expr expr);
Expr getExpr() const; Expr getExpr() const;
protected: protected:
@ -227,7 +217,8 @@ protected:
struct BinaryExpr : public Expr { struct BinaryExpr : public Expr {
friend class Expr; friend class Expr;
BinaryExpr(ExprKind kind, Expr lhs, Expr rhs); BinaryExpr(StringRef name, Type result, Expr lhs, Expr rhs,
ArrayRef<NamedAttribute> attrs = {});
Expr getLHS() const; Expr getLHS() const;
Expr getRHS() const; Expr getRHS() const;
@ -239,7 +230,7 @@ protected:
struct TernaryExpr : public Expr { struct TernaryExpr : public Expr {
friend class Expr; friend class Expr;
TernaryExpr(ExprKind kind, Expr cond, Expr lhs, Expr rhs); TernaryExpr(StringRef name, Expr cond, Expr lhs, Expr rhs);
Expr getCond() const; Expr getCond() const;
Expr getLHS() const; Expr getLHS() const;
Expr getRHS() const; Expr getRHS() const;
@ -252,8 +243,9 @@ protected:
struct VariadicExpr : public Expr { struct VariadicExpr : public Expr {
friend class Expr; friend class Expr;
VariadicExpr(ExprKind kind, llvm::ArrayRef<Expr> exprs, VariadicExpr(StringRef name, llvm::ArrayRef<Expr> exprs,
llvm::ArrayRef<Type> types = {}); llvm::ArrayRef<Type> types = {},
ArrayRef<NamedAttribute> attrs = {});
llvm::ArrayRef<Expr> getExprs() const; llvm::ArrayRef<Expr> getExprs() const;
llvm::ArrayRef<Type> getTypes() const; llvm::ArrayRef<Type> getTypes() const;
@ -554,7 +546,7 @@ namespace edsc {
/// `llvm::SmallVector<Expr, 8> dims(n);` directly because a single /// `llvm::SmallVector<Expr, 8> dims(n);` directly because a single
/// `Expr` will be default constructed and copied everywhere in the vector. /// `Expr` will be default constructed and copied everywhere in the vector.
/// Hilarity ensues when trying to bind `Expr` multiple times. /// Hilarity ensues when trying to bind `Expr` multiple times.
llvm::SmallVector<Expr, 8> makeNewExprs(unsigned n); llvm::SmallVector<Expr, 8> makeNewExprs(unsigned n, Type type);
template <typename IterTy> template <typename IterTy>
llvm::SmallVector<Expr, 8> copyExprs(IterTy begin, IterTy end) { llvm::SmallVector<Expr, 8> copyExprs(IterTy begin, IterTy end) {
return llvm::SmallVector<Expr, 8>(begin, end); return llvm::SmallVector<Expr, 8>(begin, end);

View File

@ -44,13 +44,14 @@ char LowerEDSCTestPass::passID = 0;
#include "mlir/EDSC/reference-impl.inc" #include "mlir/EDSC/reference-impl.inc"
PassResult LowerEDSCTestPass::runOnFunction(Function *f) { PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
// Inject a EDSC-constructed list of blocks.
if (f->getName().strref() == "blocks") { if (f->getName().strref() == "blocks") {
using namespace edsc::op; using namespace edsc::op;
FuncBuilder builder(f); FuncBuilder builder(f);
edsc::ScopedEDSCContext context; edsc::ScopedEDSCContext context;
edsc::Expr arg1, arg2, arg3, arg4;
auto type = builder.getIntegerType(32); auto type = builder.getIntegerType(32);
edsc::Expr arg1(type), arg2(type), arg3(type), arg4(type);
auto b1 = auto b1 =
edsc::block({arg1, arg2}, {type, type}, {arg1 + arg2, edsc::Return()}); edsc::block({arg1, arg2}, {type, type}, {arg1 + arg2, edsc::Return()});
@ -73,8 +74,9 @@ PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
"dynamic_for expected index arguments"); "dynamic_for expected index arguments");
} }
Type index = IndexType::get(f->getContext());
edsc::ScopedEDSCContext context; edsc::ScopedEDSCContext context;
edsc::Expr lb, ub, step; edsc::Expr lb(index), ub(index), step(index);
auto loop = edsc::For(lb, ub, step, {}); auto loop = edsc::For(lb, ub, step, {});
edsc::MLIREmitter(&builder, f->getLoc()) edsc::MLIREmitter(&builder, f->getLoc())
.bind(edsc::Bindable(lb), f->getArgument(0)) .bind(edsc::Bindable(lb), f->getArgument(0))
@ -83,6 +85,7 @@ PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
.emitStmt(loop); .emitStmt(loop);
return success(); return success();
} }
// Inject a EDSC-constructed `for` loop with non-constant bounds that are // Inject a EDSC-constructed `for` loop with non-constant bounds that are
// obtained from AffineApplyOp (also constructed using EDSC operator // obtained from AffineApplyOp (also constructed using EDSC operator
// overloads). // overloads).
@ -97,8 +100,9 @@ PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
"dynamic_for expected index arguments"); "dynamic_for expected index arguments");
} }
Type index = IndexType::get(f->getContext());
edsc::ScopedEDSCContext context; edsc::ScopedEDSCContext context;
edsc::Expr lb1, lb2, ub1, ub2, step; edsc::Expr lb1(index), lb2(index), ub1(index), ub2(index), step(index);
using namespace edsc::op; using namespace edsc::op;
auto lb = lb1 - lb2; auto lb = lb1 - lb2;
auto ub = ub1 + ub2; auto ub = ub1 + ub2;

View File

@ -145,7 +145,8 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) {
} }
mlir::edsc::MLIREmitter::MLIREmitter(FuncBuilder *builder, Location location) mlir::edsc::MLIREmitter::MLIREmitter(FuncBuilder *builder, Location location)
: builder(builder), location(location) { : builder(builder), location(location), zeroIndex(builder->getIndexType()),
oneIndex(builder->getIndexType()) {
// Build the ubiquitous zero and one at the top of the function. // Build the ubiquitous zero and one at the top of the function.
bindConstant<ConstantIndexOp>(Bindable(zeroIndex), 0); bindConstant<ConstantIndexOp>(Bindable(zeroIndex), 0);
bindConstant<ConstantIndexOp>(Bindable(oneIndex), 1); bindConstant<ConstantIndexOp>(Bindable(oneIndex), 1);
@ -166,140 +167,23 @@ MLIREmitter &mlir::edsc::MLIREmitter::bind(Bindable e, Value *v) {
} }
Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
// It is still necessary in case we try to emit a bindable directly
// FIXME: make sure isa<Bindable> works and use it below to delegate emission
// to Expr::build and remove this, now duplicate, check.
auto it = ssaBindings.find(e); auto it = ssaBindings.find(e);
if (it != ssaBindings.end()) { if (it != ssaBindings.end()) {
return it->second; return it->second;
} }
// Skip bindables, they must have been found already.
Value *res = nullptr; Value *res = nullptr;
if (auto un = e.dyn_cast<UnaryExpr>()) { bool expectedEmpty = false;
if (un.getKind() == ExprKind::Dealloc) { if (e.isa<UnaryExpr>() || e.isa<BinaryExpr>() || e.isa<TernaryExpr>() ||
builder->create<DeallocOp>(location, emitExpr(un.getExpr())); e.isa<VariadicExpr>()) {
return nullptr; auto results = e.build(*builder, ssaBindings);
} else if (un.getKind() == ExprKind::Negate) { assert(results.size() <= 1 && "2+-result exprs are not supported");
auto ctrue = builder->create<mlir::ConstantIntOp>(location, 1, expectedEmpty = results.empty();
builder->getI1Type()); if (!results.empty())
// TODO(dvytin): worth binding constant in ssaBindings in the future? res = results.front();
// TODO(dvytin): no need to cast getExpr() to I1?
auto val = emitExpr(un.getExpr());
assert(val->getType().isInteger(1) &&
"Logical Negate expects i1 operand");
return sub(builder, location, ctrue, val);
}
} else if (auto bin = e.dyn_cast<BinaryExpr>()) {
auto lhs = bin.getLHS();
auto rhs = bin.getRHS();
auto *a = emitExpr(lhs);
auto *b = emitExpr(rhs);
if (!a || !b) {
return nullptr;
}
if (bin.getKind() == ExprKind::Add) {
res = add(builder, location, a, b);
} else if (bin.getKind() == ExprKind::Sub) {
res = sub(builder, location, a, b);
} else if (bin.getKind() == ExprKind::Mul) {
res = mul(builder, location, a, b);
} else if (bin.getKind() == ExprKind::And) {
// Operands should both be i1
assert(a->getType().isInteger(1) && "Logical And expects i1 LHS");
assert(b->getType().isInteger(1) && "Logical And expects i1 RHS");
res = mul(builder, location, a, b);
} else if (bin.getKind() == ExprKind::Or) {
assert(a->getType().isInteger(1) && "Logical Or expects i1 LHS");
assert(b->getType().isInteger(1) && "Logical Or expects i1 RHS");
// a || b = not (not a && not b)
using edsc::op::operator!;
using edsc::op::operator&&;
res = emitExpr(!(!lhs && !rhs));
} // TODO(ntv): signed vs unsiged ??
// TODO(ntv): integer vs not ??
// TODO(ntv): float cmp
else if (bin.getKind() == ExprKind::EQ) {
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::EQ, a, b);
} else if (bin.getKind() == ExprKind::NE) {
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::NE, a, b);
} else if (bin.getKind() == ExprKind::LT) {
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::SLT, a, b);
} else if (bin.getKind() == ExprKind::LE) {
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::SLE, a, b);
} else if (bin.getKind() == ExprKind::GT) {
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::SGT, a, b);
} else if (bin.getKind() == ExprKind::GE) {
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::SGE, a, b);
}
// TODO(ntv): do we want this?
// if (res && ((a->type().is_uint() && !b->type().is_uint()) ||
// (!a->type().is_uint() && b->type().is_uint()))) {
// std::stringstream ss;
// ss << "a: " << *a << "\t b: " << *b;
// res->getDefiningOperation()->emitWarning(
// "Mixing signed and unsigned integers: " + ss.str());
// }
// }
}
if (auto ter = e.dyn_cast<TernaryExpr>()) {
if (ter.getKind() == ExprKind::Select) {
auto *cond = emitExpr(ter.getCond());
auto *lhs = emitExpr(ter.getLHS());
auto *rhs = emitExpr(ter.getRHS());
if (!cond || !rhs || !lhs) {
return nullptr;
}
res = builder->create<SelectOp>(location, cond, lhs, rhs)->getResult();
}
}
if (auto nar = e.dyn_cast<VariadicExpr>()) {
if (nar.getKind() == ExprKind::Alloc) {
auto exprs = emitExprs(nar.getExprs());
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
return nullptr;
}
auto types = nar.getTypes();
assert(types.size() == 1 && "Expected 1 type");
res =
builder->create<AllocOp>(location, types[0].cast<MemRefType>(), exprs)
->getResult();
} else if (nar.getKind() == ExprKind::Load) {
auto exprs = emitExprs(nar.getExprs());
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
return nullptr;
}
assert(!exprs.empty() && "Load requires >= 1 exprs");
assert(nar.getTypes().empty() && "Load expects no type");
SmallVector<Value *, 8> vals(exprs.begin() + 1, exprs.end());
res = builder->create<LoadOp>(location, exprs[0], vals)->getResult();
} else if (nar.getKind() == ExprKind::Store) {
auto exprs = emitExprs(nar.getExprs());
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
return nullptr;
}
assert(exprs.size() >= 2 && "Store requires >= 2 exprs");
assert(nar.getTypes().empty() && "Store expects no type");
SmallVector<Value *, 8> vals(exprs.begin() + 2, exprs.end());
builder->create<StoreOp>(location, exprs[0], exprs[1], vals);
return nullptr;
} else if (nar.getKind() == ExprKind::VectorTypeCast) {
auto exprs = emitExprs(nar.getExprs());
if (llvm::any_of(exprs, [](Value *v) { return !v; })) {
return nullptr;
}
assert(exprs.size() == 1 && "Expected 1 expr");
auto types = nar.getTypes();
assert(types.size() == 1 && "Expected 1 type");
res = builder
->create<VectorTypeCastOp>(location, exprs[0],
types[0].cast<MemRefType>())
->getResult();
} else if (nar.getKind() == ExprKind::Return) {
auto exprs = emitExprs(nar.getExprs());
builder->create<ReturnOp>(location, exprs);
return nullptr; // no Value* produced and this is fine.
}
} }
if (auto expr = e.dyn_cast<StmtBlockLikeExpr>()) { if (auto expr = e.dyn_cast<StmtBlockLikeExpr>()) {
@ -349,7 +233,7 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
} }
} }
if (!res) { if (!res && !expectedEmpty) {
// If we hit here it must mean that the Bindables have not all been bound // If we hit here it must mean that the Bindables have not all been bound
// properly. Because EDSCs are currently dynamically typed, it becomes a // properly. Because EDSCs are currently dynamically typed, it becomes a
// runtime error. // runtime error.
@ -386,9 +270,9 @@ void mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) {
auto ip = builder->getInsertionPoint(); auto ip = builder->getInsertionPoint();
auto *val = emitExpr(stmt.getRHS()); auto *val = emitExpr(stmt.getRHS());
if (!val) { if (!val) {
assert((stmt.getRHS().getKind() == ExprKind::Dealloc || assert((stmt.getRHS().getName() == DeallocOp::getOperationName() ||
stmt.getRHS().getKind() == ExprKind::Store || stmt.getRHS().getName() == StoreOp::getOperationName() ||
stmt.getRHS().getKind() == ExprKind::Return) && stmt.getRHS().getName() == ReturnOp::getOperationName()) &&
"dealloc, store or return expected as the only 0-result ops"); "dealloc, store or return expected as the only 0-result ops");
return; return;
} }
@ -491,7 +375,7 @@ mlir::edsc::MLIREmitter::makeBoundFunctionArguments(mlir::Function *function) {
for (unsigned pos = 0, npos = function->getNumArguments(); pos < npos; for (unsigned pos = 0, npos = function->getNumArguments(); pos < npos;
++pos) { ++pos) {
auto *arg = function->getArgument(pos); auto *arg = function->getArgument(pos);
Expr b; Expr b(arg->getType());
bind(Bindable(b), arg); bind(Bindable(b), arg);
res.push_back(Expr(b)); res.push_back(Expr(b));
} }
@ -502,7 +386,8 @@ SmallVector<edsc::Expr, 8>
mlir::edsc::MLIREmitter::makeBoundMemRefShape(Value *memRef) { mlir::edsc::MLIREmitter::makeBoundMemRefShape(Value *memRef) {
assert(memRef->getType().isa<MemRefType>() && "Expected a MemRef value"); assert(memRef->getType().isa<MemRefType>() && "Expected a MemRef value");
MemRefType memRefType = memRef->getType().cast<MemRefType>(); MemRefType memRefType = memRef->getType().cast<MemRefType>();
auto memRefSizes = edsc::makeNewExprs(memRefType.getShape().size()); auto memRefSizes =
edsc::makeNewExprs(memRefType.getShape().size(), builder->getIndexType());
auto memrefSizeValues = getMemRefSizes(getBuilder(), getLocation(), memRef); auto memrefSizeValues = getMemRefSizes(getBuilder(), getLocation(), memRef);
assert(memrefSizeValues.size() == memRefSizes.size()); assert(memrefSizeValues.size() == memRefSizes.size());
bindZipRange(llvm::zip(memRefSizes, memrefSizeValues)); bindZipRange(llvm::zip(memRefSizes, memrefSizeValues));
@ -517,7 +402,7 @@ mlir::edsc::MLIREmitter::makeBoundMemRefView(Value *memRef) {
SmallVector<edsc::Expr, 8> lbs; SmallVector<edsc::Expr, 8> lbs;
lbs.reserve(rank); lbs.reserve(rank);
Expr zero; Expr zero(builder->getIndexType());
bindConstant<mlir::ConstantIndexOp>(Bindable(zero), 0); bindConstant<mlir::ConstantIndexOp>(Bindable(zero), 0);
for (unsigned i = 0; i < rank; ++i) { for (unsigned i = 0; i < rank; ++i) {
lbs.push_back(zero); lbs.push_back(zero);
@ -527,7 +412,7 @@ mlir::edsc::MLIREmitter::makeBoundMemRefView(Value *memRef) {
SmallVector<edsc::Expr, 8> steps; SmallVector<edsc::Expr, 8> steps;
lbs.reserve(rank); lbs.reserve(rank);
Expr one; Expr one(builder->getIndexType());
bindConstant<mlir::ConstantIndexOp>(Bindable(one), 1); bindConstant<mlir::ConstantIndexOp>(Bindable(one), 1);
for (unsigned i = 0; i < rank; ++i) { for (unsigned i = 0; i < rank; ++i) {
steps.push_back(one); steps.push_back(one);
@ -545,7 +430,7 @@ mlir::edsc::MLIREmitter::makeBoundMemRefView(Expr boundMemRef) {
edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value) { edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter); auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Expr b; Expr b(e->getBuilder()->getBF16Type());
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value), e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value),
e->getBuilder()->getBF16Type()); e->getBuilder()->getBF16Type());
return b; return b;
@ -553,7 +438,7 @@ edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value) {
edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value) { edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter); auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Expr b; Expr b(e->getBuilder()->getBF16Type());
bool unused; bool unused;
mlir::APFloat val(value); mlir::APFloat val(value);
val.convert(e->getBuilder()->getF16Type().getFloatSemantics(), val.convert(e->getBuilder()->getF16Type().getFloatSemantics(),
@ -565,7 +450,7 @@ edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value) {
edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value) { edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter); auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Expr b; Expr b(e->getBuilder()->getF32Type());
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value), e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value),
e->getBuilder()->getF32Type()); e->getBuilder()->getF32Type());
return b; return b;
@ -573,7 +458,7 @@ edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value) {
edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value) { edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter); auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Expr b; Expr b(e->getBuilder()->getF64Type());
e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value), e->bindConstant<mlir::ConstantFloatOp>(Bindable(b), mlir::APFloat(value),
e->getBuilder()->getF64Type()); e->getBuilder()->getF64Type());
return b; return b;
@ -582,7 +467,7 @@ edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value) {
edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value, edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value,
unsigned bitwidth) { unsigned bitwidth) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter); auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Expr b; Expr b(e->getBuilder()->getIntegerType(bitwidth));
e->bindConstant<mlir::ConstantIntOp>( e->bindConstant<mlir::ConstantIntOp>(
b, value, e->getBuilder()->getIntegerType(bitwidth)); b, value, e->getBuilder()->getIntegerType(bitwidth));
return b; return b;
@ -590,7 +475,7 @@ edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value,
edsc_expr_t bindConstantIndex(edsc_mlir_emitter_t emitter, int64_t value) { edsc_expr_t bindConstantIndex(edsc_mlir_emitter_t emitter, int64_t value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter); auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Expr b; Expr b(e->getBuilder()->getIndexType());
e->bindConstant<mlir::ConstantIndexOp>(Bindable(b), value); e->bindConstant<mlir::ConstantIndexOp>(Bindable(b), value);
return b; return b;
} }
@ -618,7 +503,7 @@ edsc_expr_t bindFunctionArgument(edsc_mlir_emitter_t emitter,
auto *f = reinterpret_cast<mlir::Function *>(function); auto *f = reinterpret_cast<mlir::Function *>(function);
assert(pos < f->getNumArguments()); assert(pos < f->getNumArguments());
auto *arg = *(f->getArguments().begin() + pos); auto *arg = *(f->getArguments().begin() + pos);
Expr b; Expr b(arg->getType());
e->bind(Bindable(b), arg); e->bind(Bindable(b), arg);
return Expr(b); return Expr(b);
} }
@ -630,7 +515,7 @@ void bindFunctionArguments(edsc_mlir_emitter_t emitter, mlir_func_t function,
assert(result->n == f->getNumArguments()); assert(result->n == f->getNumArguments());
for (unsigned pos = 0; pos < result->n; ++pos) { for (unsigned pos = 0; pos < result->n; ++pos) {
auto *arg = *(f->getArguments().begin() + pos); auto *arg = *(f->getArguments().begin() + pos);
Expr b; Expr b(arg->getType());
e->bind(Bindable(b), arg); e->bind(Bindable(b), arg);
result->exprs[pos] = Expr(b); result->exprs[pos] = Expr(b);
} }
@ -670,9 +555,9 @@ void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
assert(resultUbs->n == rank && "Unexpected memref binding results count"); assert(resultUbs->n == rank && "Unexpected memref binding results count");
assert(resultSteps->n == rank && "Unexpected memref binding results count"); assert(resultSteps->n == rank && "Unexpected memref binding results count");
auto bindables = e->makeBoundMemRefShape(v); auto bindables = e->makeBoundMemRefShape(v);
Expr zero; Expr zero(e->getBuilder()->getIndexType());
e->bindConstant<mlir::ConstantIndexOp>(zero, 0); e->bindConstant<mlir::ConstantIndexOp>(zero, 0);
Expr one; Expr one(e->getBuilder()->getIndexType());
e->bindConstant<mlir::ConstantIndexOp>(one, 1); e->bindConstant<mlir::ConstantIndexOp>(one, 1);
for (unsigned i = 0; i < rank; ++i) { for (unsigned i = 0; i < rank; ++i) {
resultLbs->exprs[i] = zero; resultLbs->exprs[i] = zero;

View File

@ -17,9 +17,16 @@
#include "mlir/EDSC/Types.h" #include "mlir/EDSC/Types.h"
#include "mlir-c/Core.h" #include "mlir-c/Core.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h" #include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/STLExtras.h" #include "mlir/Support/STLExtras.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
@ -54,16 +61,23 @@ struct ExprStorage {
ExprKind kind; ExprKind kind;
unsigned id; unsigned id;
StringRef opName;
ArrayRef<Expr> operands; ArrayRef<Expr> operands;
ArrayRef<Type> resultTypes; ArrayRef<Type> resultTypes;
ArrayRef<NamedAttribute> attributes; ArrayRef<NamedAttribute> attributes;
ExprStorage(ExprKind kind, ArrayRef<Type> results, ArrayRef<Expr> children, ExprStorage(ExprKind kind, StringRef name, ArrayRef<Type> results,
ArrayRef<NamedAttribute> attrs, unsigned exprId = Expr::newId()) ArrayRef<Expr> children, ArrayRef<NamedAttribute> attrs,
StringRef descr = "", unsigned exprId = Expr::newId())
: kind(kind), id(exprId) { : kind(kind), id(exprId) {
operands = copyIntoExprAllocator(children); operands = copyIntoExprAllocator(children);
resultTypes = copyIntoExprAllocator(results); resultTypes = copyIntoExprAllocator(results);
attributes = copyIntoExprAllocator(attrs); attributes = copyIntoExprAllocator(attrs);
if (!name.empty()) {
auto nameStorage = Expr::globalAllocator()->Allocate<char>(name.size());
std::uninitialized_copy(name.begin(), name.end(), nameStorage);
opName = StringRef(nameStorage, name.size());
}
} }
}; };
@ -101,10 +115,10 @@ mlir::edsc::ScopedEDSCContext::~ScopedEDSCContext() {
Expr::globalAllocator() = nullptr; Expr::globalAllocator() = nullptr;
} }
mlir::edsc::Expr::Expr() { mlir::edsc::Expr::Expr(Type type) {
// Initialize with placement new. // Initialize with placement new.
storage = Expr::globalAllocator()->Allocate<detail::ExprStorage>(); storage = Expr::globalAllocator()->Allocate<detail::ExprStorage>();
new (storage) detail::ExprStorage(ExprKind::Unbound, {}, {}, {}); new (storage) detail::ExprStorage(ExprKind::Unbound, "", {type}, {}, {});
} }
ExprKind mlir::edsc::Expr::getKind() const { return storage->kind; } ExprKind mlir::edsc::Expr::getKind() const { return storage->kind; }
@ -118,49 +132,173 @@ unsigned &mlir::edsc::Expr::newId() {
return ++id; return ++id;
} }
ArrayRef<Type> mlir::edsc::Expr::getResultTypes() const {
return storage->resultTypes;
}
ArrayRef<Expr> mlir::edsc::Expr::getChildExpressions() const {
return storage->operands;
}
ArrayRef<NamedAttribute> mlir::edsc::Expr::getAttributes() const {
return storage->attributes;
}
StringRef mlir::edsc::Expr::getName() const {
return static_cast<ImplType *>(storage)->opName;
}
SmallVector<Value *, 4>
Expr::build(FuncBuilder &b,
const llvm::DenseMap<Expr, Value *> &ssaBindings) const {
auto it = ssaBindings.find(*this);
if (it != ssaBindings.end())
return {it->second};
auto *impl = static_cast<ImplType *>(storage);
auto state = OperationState(b.getContext(), b.getUnknownLoc(), impl->opName);
SmallVector<Value *, 4> operandValues;
operandValues.reserve(impl->operands.size());
for (auto child : impl->operands) {
auto subResults = child.build(b, ssaBindings);
assert(subResults.size() == 1 &&
"expected single-result expression as operand");
operandValues.push_back(subResults.front());
}
state.addOperands(operandValues);
state.addTypes(impl->resultTypes);
for (const auto &attr : impl->attributes)
state.addAttribute(attr.first, attr.second);
Instruction *inst = b.createOperation(state);
return llvm::to_vector<4>(inst->getResults());
}
static Expr createBinaryExpr(
Expr lhs, Expr rhs, StringRef intOpName, StringRef floatOpName,
std::function<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
"only single-result exprs are supported in operators");
auto thisType = lhs.getResultTypes().front();
auto thatType = rhs.getResultTypes().front();
assert(thisType == thatType && "cannot mix types in operators");
StringRef opName;
if (thisType.isIndex()) {
MLIRContext *context = thisType.getContext();
auto d0 = getAffineDimExpr(0, context);
auto d1 = getAffineDimExpr(1, context);
auto map = AffineMap::get(2, 0, {affCombiner(d0, d1)}, {});
auto attr = AffineMapAttr::get(map);
auto attrId = Identifier::get("map", context);
auto namedAttr = NamedAttribute{attrId, attr};
return VariadicExpr("affine.apply", {lhs, rhs}, {IndexType::get(context)},
{namedAttr});
} else if (thisType.isa<IntegerType>()) {
opName = intOpName;
} else if (thisType.isa<FloatType>()) {
opName = floatOpName;
} else if (auto aggregateType = thisType.dyn_cast<VectorOrTensorType>()) {
if (aggregateType.getElementType().isa<IntegerType>())
opName = intOpName;
else if (aggregateType.getElementType().isa<FloatType>())
opName = floatOpName;
}
if (!opName.empty())
return BinaryExpr(opName, thisType, lhs, rhs);
llvm_unreachable("failed to create an Expr");
}
Expr mlir::edsc::op::operator+(Expr lhs, Expr rhs) { Expr mlir::edsc::op::operator+(Expr lhs, Expr rhs) {
return BinaryExpr(ExprKind::Add, lhs, rhs); return createBinaryExpr(lhs, rhs, "addi", "addf",
[](AffineExpr d0, AffineExpr d1) { return d0 + d1; });
} }
Expr mlir::edsc::op::operator-(Expr lhs, Expr rhs) { Expr mlir::edsc::op::operator-(Expr lhs, Expr rhs) {
return BinaryExpr(ExprKind::Sub, lhs, rhs); return createBinaryExpr(lhs, rhs, "subi", "subf",
[](AffineExpr d0, AffineExpr d1) { return d0 - d1; });
} }
Expr mlir::edsc::op::operator*(Expr lhs, Expr rhs) { Expr mlir::edsc::op::operator*(Expr lhs, Expr rhs) {
return BinaryExpr(ExprKind::Mul, lhs, rhs); return createBinaryExpr(lhs, rhs, "muli", "mulf",
[](AffineExpr d0, AffineExpr d1) { return d0 * d1; });
}
static Expr createComparisonExpr(CmpIPredicate predicate, Expr lhs, Expr rhs) {
assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
"only single-result exprs are supported in operators");
auto lhsType = lhs.getResultTypes().front();
auto rhsType = rhs.getResultTypes().front();
assert(lhsType == rhsType && "cannot mix types in operators");
assert((lhsType.isa<IndexType>() || lhsType.isa<IntegerType>()) &&
"only integer comparisons are supported");
MLIRContext *context = lhsType.getContext();
auto attr = IntegerAttr::get(IndexType::get(context),
static_cast<int64_t>(predicate));
auto attrId = Identifier::get("predicate", context);
auto namedAttr = NamedAttribute{attrId, attr};
return BinaryExpr("cmpi", IntegerType::get(1, context), lhs, rhs,
{namedAttr});
} }
Expr mlir::edsc::op::operator==(Expr lhs, Expr rhs) { Expr mlir::edsc::op::operator==(Expr lhs, Expr rhs) {
return BinaryExpr(ExprKind::EQ, lhs, rhs); return createComparisonExpr(CmpIPredicate::EQ, lhs, rhs);
} }
Expr mlir::edsc::op::operator!=(Expr lhs, Expr rhs) { Expr mlir::edsc::op::operator!=(Expr lhs, Expr rhs) {
return BinaryExpr(ExprKind::NE, lhs, rhs); return createComparisonExpr(CmpIPredicate::NE, lhs, rhs);
} }
Expr mlir::edsc::op::operator<(Expr lhs, Expr rhs) { Expr mlir::edsc::op::operator<(Expr lhs, Expr rhs) {
return BinaryExpr(ExprKind::LT, lhs, rhs); // TODO(ntv,zinenko): signed by default, how about unsigned?
return createComparisonExpr(CmpIPredicate::SLT, lhs, rhs);
} }
Expr mlir::edsc::op::operator<=(Expr lhs, Expr rhs) { Expr mlir::edsc::op::operator<=(Expr lhs, Expr rhs) {
return BinaryExpr(ExprKind::LE, lhs, rhs); return createComparisonExpr(CmpIPredicate::SLE, lhs, rhs);
} }
Expr mlir::edsc::op::operator>(Expr lhs, Expr rhs) { Expr mlir::edsc::op::operator>(Expr lhs, Expr rhs) {
return BinaryExpr(ExprKind::GT, lhs, rhs); return createComparisonExpr(CmpIPredicate::SGT, lhs, rhs);
} }
Expr mlir::edsc::op::operator>=(Expr lhs, Expr rhs) { Expr mlir::edsc::op::operator>=(Expr lhs, Expr rhs) {
return BinaryExpr(ExprKind::GE, lhs, rhs); return createComparisonExpr(CmpIPredicate::SGE, lhs, rhs);
}
Expr mlir::edsc::op::operator&&(Expr lhs, Expr rhs) {
return BinaryExpr(ExprKind::And, lhs, rhs);
}
Expr mlir::edsc::op::operator||(Expr lhs, Expr rhs) {
return BinaryExpr(ExprKind::Or, lhs, rhs);
}
Expr mlir::edsc::op::operator!(Expr expr) {
return UnaryExpr(ExprKind::Negate, expr);
} }
llvm::SmallVector<Expr, 8> mlir::edsc::makeNewExprs(unsigned n) { Expr mlir::edsc::op::operator&&(Expr lhs, Expr rhs) {
assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
"expected single-result exprs");
auto thisType = lhs.getResultTypes().front();
auto thatType = rhs.getResultTypes().front();
assert(thisType.isInteger(1) && thatType.isInteger(1) &&
"logical And expects i1");
return BinaryExpr("muli", thisType, lhs, rhs);
}
Expr mlir::edsc::op::operator||(Expr lhs, Expr rhs) {
// There is not support for bitwise operations, so we emulate logical 'or'
// lhs || rhs
// as
// !(!lhs && !rhs).
using namespace edsc::op;
return !(!lhs && !rhs);
}
Expr mlir::edsc::op::operator!(Expr expr) {
assert(expr.getResultTypes().size() == 1 && "expected single-result exprs");
auto thisType = expr.getResultTypes().front();
assert(thisType.isInteger(1) && "logical Not expects i1");
MLIRContext *context = thisType.getContext();
// Create constant 1 expression.s
auto attr = IntegerAttr::get(thisType, 1);
auto attrId = Identifier::get("value", context);
auto namedAttr = NamedAttribute{attrId, attr};
auto cstOne = VariadicExpr("constant", {}, thisType, {namedAttr});
// Emulate negation as (1 - x) : i1
return cstOne - expr;
}
llvm::SmallVector<Expr, 8> mlir::edsc::makeNewExprs(unsigned n, Type type) {
llvm::SmallVector<Expr, 8> res; llvm::SmallVector<Expr, 8> res;
res.reserve(n); res.reserve(n);
for (auto i = 0; i < n; ++i) { for (auto i = 0; i < n; ++i) {
res.push_back(Expr()); res.push_back(Expr(type));
} }
return res; return res;
} }
@ -183,15 +321,15 @@ static void fillStmts(edsc_stmt_list_t enclosedStmts,
} }
Expr mlir::edsc::alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) { Expr mlir::edsc::alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) {
return VariadicExpr(ExprKind::Alloc, sizes, memrefType); return VariadicExpr("alloc", sizes, memrefType);
} }
Expr mlir::edsc::dealloc(Expr memref) { Expr mlir::edsc::dealloc(Expr memref) { return UnaryExpr("dealloc", memref); }
return UnaryExpr(ExprKind::Dealloc, memref);
}
Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) { Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) {
Expr idx; assert(lb.getResultTypes().size() == 1 && "expected single-result bounds");
auto type = lb.getResultTypes().front();
Expr idx(type);
return For(Bindable(idx), lb, ub, step, stmts); return For(Bindable(idx), lb, ub, step, stmts);
} }
@ -249,10 +387,14 @@ edsc_block_t Block(edsc_stmt_list_t enclosedStmts) {
} }
Expr mlir::edsc::load(Expr m, ArrayRef<Expr> indices) { Expr mlir::edsc::load(Expr m, ArrayRef<Expr> indices) {
assert(m.getResultTypes().size() == 1 && "expected single-result expr");
auto type = m.getResultTypes().front().dyn_cast<MemRefType>();
assert(type && "expected memref type");
SmallVector<Expr, 8> exprs; SmallVector<Expr, 8> exprs;
exprs.push_back(m); exprs.push_back(m);
exprs.append(indices.begin(), indices.end()); exprs.append(indices.begin(), indices.end());
return VariadicExpr(ExprKind::Load, exprs); return VariadicExpr("load", exprs, {type.getElementType()});
} }
edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices) { edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices) {
@ -267,7 +409,7 @@ Expr mlir::edsc::store(Expr val, Expr m, ArrayRef<Expr> indices) {
exprs.push_back(val); exprs.push_back(val);
exprs.push_back(m); exprs.push_back(m);
exprs.append(indices.begin(), indices.end()); exprs.append(indices.begin(), indices.end());
return VariadicExpr(ExprKind::Store, exprs); return VariadicExpr("store", exprs);
} }
edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed, edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed,
@ -279,7 +421,7 @@ edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed,
} }
Expr mlir::edsc::select(Expr cond, Expr lhs, Expr rhs) { Expr mlir::edsc::select(Expr cond, Expr lhs, Expr rhs) {
return TernaryExpr(ExprKind::Select, cond, lhs, rhs); return TernaryExpr("select", cond, lhs, rhs);
} }
edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs) { edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs) {
@ -287,106 +429,210 @@ edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs) {
} }
Expr mlir::edsc::vector_type_cast(Expr memrefExpr, Type memrefType) { Expr mlir::edsc::vector_type_cast(Expr memrefExpr, Type memrefType) {
return VariadicExpr(ExprKind::VectorTypeCast, {memrefExpr}, {memrefType}); return VariadicExpr("vector_type_cast", {memrefExpr}, {memrefType});
} }
Stmt mlir::edsc::Return(ArrayRef<Expr> values) { Stmt mlir::edsc::Return(ArrayRef<Expr> values) {
return VariadicExpr(ExprKind::Return, values); return VariadicExpr("return", values);
} }
edsc_stmt_t Return(edsc_expr_list_t values) { edsc_stmt_t Return(edsc_expr_list_t values) {
return Stmt(Return(makeExprs(values))); return Stmt(Return(makeExprs(values)));
} }
static raw_ostream &printBinaryExpr(raw_ostream &os, BinaryExpr e,
StringRef infix) {
os << '(' << e.getLHS() << ' ' << infix << ' ' << e.getRHS() << ')';
return os;
}
// Get the operator spelling for pretty-printing the infix form of a
// comparison operator.
static StringRef getCmpIPredicateInfix(const mlir::edsc::Expr &e) {
Attribute predicate;
for (const auto &namedAttr : e.getAttributes()) {
if (namedAttr.first.is(CmpIOp::getPredicateAttrName())) {
predicate = namedAttr.second;
break;
}
}
assert(predicate && "expected a predicate in a comparison expr");
switch (static_cast<CmpIPredicate>(
predicate.cast<IntegerAttr>().getValue().getSExtValue())) {
case CmpIPredicate::EQ:
return "==";
case CmpIPredicate::NE:
return "!=";
case CmpIPredicate::SGT:
case CmpIPredicate::UGT:
return ">";
case CmpIPredicate::SLT:
case CmpIPredicate::ULT:
return "<";
case CmpIPredicate::SGE:
case CmpIPredicate::UGE:
return ">=";
case CmpIPredicate::SLE:
case CmpIPredicate::ULE:
return "<=";
default:
llvm_unreachable("unknown predicate");
}
return "";
}
static void printAffineExpr(raw_ostream &os, AffineExpr expr,
ArrayRef<Expr> dims, ArrayRef<Expr> symbols) {
struct Visitor : public AffineExprVisitor<Visitor> {
Visitor(raw_ostream &ostream, ArrayRef<Expr> dimExprs,
ArrayRef<Expr> symExprs)
: os(ostream), dims(dimExprs), symbols(symExprs) {}
raw_ostream &os;
ArrayRef<Expr> dims;
ArrayRef<Expr> symbols;
void visitDimExpr(AffineDimExpr dimExpr) {
os << dims[dimExpr.getPosition()];
}
void visitSymbolExpr(AffineSymbolExpr symbolExpr) {
os << symbols[symbolExpr.getPosition()];
}
void visitConstantExpr(AffineConstantExpr constExpr) {
os << constExpr.getValue();
}
void visitBinaryExpr(AffineBinaryOpExpr expr, StringRef infix) {
visit(expr.getLHS());
os << infix;
visit(expr.getRHS());
}
void visitAddExpr(AffineBinaryOpExpr binOp) {
visitBinaryExpr(binOp, " + ");
}
void visitMulExpr(AffineBinaryOpExpr binOp) {
visitBinaryExpr(binOp, " * ");
}
void visitModExpr(AffineBinaryOpExpr binOp) {
visitBinaryExpr(binOp, " % ");
}
void visitCeilDivExpr(AffineBinaryOpExpr binOp) {
visitBinaryExpr(binOp, " ceildiv ");
}
void visitFloorDivExpr(AffineBinaryOpExpr binOp) {
visitBinaryExpr(binOp, " floordiv ");
}
};
Visitor(os, dims, symbols).visit(expr);
}
static void printAffineMap(raw_ostream &os, AffineMap map,
ArrayRef<Expr> operands) {
auto dims = operands.take_front(map.getNumDims());
auto symbols = operands.drop_front(map.getNumDims());
assert(map.getNumResults() == 1 &&
"only 1-result maps are currently supported");
printAffineExpr(os, map.getResult(0), dims, symbols);
}
void printAffineApply(raw_ostream &os, mlir::edsc::Expr e) {
Attribute mapAttr;
for (const auto &namedAttr : e.getAttributes()) {
if (namedAttr.first.is("map")) {
mapAttr = namedAttr.second;
break;
}
}
assert(mapAttr && "expected a map in an affine apply expr");
printAffineMap(os, mapAttr.cast<AffineMapAttr>().getValue(),
e.getChildExpressions());
}
void mlir::edsc::Expr::print(raw_ostream &os) const { void mlir::edsc::Expr::print(raw_ostream &os) const {
if (auto unbound = this->dyn_cast<Bindable>()) { if (auto unbound = this->dyn_cast<Bindable>()) {
os << "$" << unbound.getId(); os << "$" << unbound.getId();
return; return;
} else if (auto un = this->dyn_cast<UnaryExpr>()) { }
switch (un.getKind()) {
case ExprKind::Negate: // Handle known binary ops with pretty infix forms.
os << "~"; if (auto binExpr = this->dyn_cast<BinaryExpr>()) {
break; StringRef name = getName();
default: { StringRef infix;
os << "unknown_unary"; if (name == AddIOp::getOperationName() ||
name == AddFOp::getOperationName())
infix = "+";
else if (name == SubIOp::getOperationName() ||
name == SubFOp::getOperationName())
infix = "-";
else if (name == MulIOp::getOperationName() ||
name == MulFOp::getOperationName())
infix = binExpr.getResultTypes().front().isInteger(1) ? "&&" : "*";
else if (name == DivIUOp::getOperationName() ||
name == DivISOp::getOperationName())
infix = "/";
else if (name == RemIUOp::getOperationName() ||
name == RemISOp::getOperationName())
infix = "%";
else if (name == CmpIOp::getOperationName())
infix = getCmpIPredicateInfix(*this);
if (!infix.empty()) {
printBinaryExpr(os, binExpr, infix);
return;
} }
}
// Handle known variadic ops with pretty forms.
if (auto narExpr = this->dyn_cast<VariadicExpr>()) {
StringRef name = getName();
if (name == LoadOp::getOperationName()) {
os << name << '(' << getChildExpressions().front() << '[';
interleaveComma(getChildExpressions().drop_front(), os);
os << "])";
return;
} }
os << un.getExpr(); if (name == StoreOp::getOperationName()) {
} else if (auto bin = this->dyn_cast<BinaryExpr>()) { os << name << '(' << getChildExpressions().front() << ", "
os << "(" << bin.getLHS(); << getChildExpressions()[1] << '[';
switch (bin.getKind()) { interleaveComma(getChildExpressions().drop_front(2), os);
case ExprKind::Add: os << "])";
os << " + "; return;
break;
case ExprKind::Sub:
os << " - ";
break;
case ExprKind::Mul:
os << " * ";
break;
case ExprKind::Div:
os << " / ";
break;
case ExprKind::LT:
os << " < ";
break;
case ExprKind::LE:
os << " <= ";
break;
case ExprKind::GT:
os << " > ";
break;
case ExprKind::GE:
os << " >= ";
break;
case ExprKind::EQ:
os << " == ";
break;
case ExprKind::NE:
os << " != ";
break;
case ExprKind::And:
os << " && ";
break;
case ExprKind::Or:
os << " || ";
break;
default: {
os << "unknown_binary";
} }
if (name == AffineApplyOp::getOperationName()) {
os << '(';
printAffineApply(os, *this);
os << ')';
return;
} }
os << bin.getRHS() << ")"; }
// Handle all other types of ops with a more generic printing form.
if (this->isa<UnaryExpr>() || this->isa<BinaryExpr>() ||
this->isa<TernaryExpr>() || this->isa<VariadicExpr>()) {
os << (getName().empty() ? "##unknown##" : getName()) << '(';
interleaveComma(getChildExpressions(), os);
auto attrs = getAttributes();
if (!attrs.empty()) {
os << '{';
interleave(
attrs,
[&os](const NamedAttribute &attr) {
os << attr.first.strref() << ": " << attr.second;
},
[&os]() { os << ", "; });
os << '}';
}
os << ')';
return; return;
} else if (auto ter = this->dyn_cast<TernaryExpr>()) {
switch (ter.getKind()) {
case ExprKind::Select:
os << "select(" << ter.getCond() << ", " << ter.getLHS() << ", "
<< ter.getRHS() << ")";
return;
default: {
os << "unknown_ternary";
}
}
} else if (auto nar = this->dyn_cast<VariadicExpr>()) {
auto exprs = nar.getExprs();
switch (nar.getKind()) {
case ExprKind::Load:
os << "load(" << exprs[0] << "[";
interleaveComma(ArrayRef<Expr>(exprs.begin() + 1, exprs.size() - 1), os);
os << "])";
return;
case ExprKind::Store:
os << "store(" << exprs[0] << ", " << exprs[1] << "[";
interleaveComma(ArrayRef<Expr>(exprs.begin() + 2, exprs.size() - 2), os);
os << "])";
return;
case ExprKind::Return:
interleaveComma(exprs, os);
return;
default: {
os << "unknown_variadic";
}
}
} else if (auto stmtLikeExpr = this->dyn_cast<StmtBlockLikeExpr>()) { } else if (auto stmtLikeExpr = this->dyn_cast<StmtBlockLikeExpr>()) {
auto exprs = stmtLikeExpr.getExprs(); auto exprs = stmtLikeExpr.getExprs();
switch (stmtLikeExpr.getKind()) { switch (stmtLikeExpr.getKind()) {
@ -419,21 +665,26 @@ llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
return os; return os;
} }
edsc_expr_t makeBindable() { return Bindable(Expr()); } edsc_expr_t makeBindable(mlir_type_t type) {
return Bindable(Expr(Type(reinterpret_cast<const Type::ImplType *>(type))));
}
mlir::edsc::UnaryExpr::UnaryExpr(ExprKind kind, Expr expr) mlir::edsc::UnaryExpr::UnaryExpr(StringRef name, Expr expr)
: Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) { : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
// Initialize with placement new. // Initialize with placement new.
new (storage) detail::ExprStorage(kind, {}, {expr}, {}); new (storage)
detail::ExprStorage(ExprKind::FIRST_UNARY_EXPR, name, {}, {expr}, {});
} }
Expr mlir::edsc::UnaryExpr::getExpr() const { Expr mlir::edsc::UnaryExpr::getExpr() const {
return static_cast<ImplType *>(storage)->operands.front(); return static_cast<ImplType *>(storage)->operands.front();
} }
mlir::edsc::BinaryExpr::BinaryExpr(ExprKind kind, Expr lhs, Expr rhs) mlir::edsc::BinaryExpr::BinaryExpr(StringRef name, Type result, Expr lhs,
Expr rhs, ArrayRef<NamedAttribute> attrs)
: Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) { : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
// Initialize with placement new. // Initialize with placement new.
new (storage) detail::ExprStorage(kind, {}, {lhs, rhs}, {}); new (storage) detail::ExprStorage(ExprKind::FIRST_BINARY_EXPR, name, {result},
{lhs, rhs}, attrs);
} }
Expr mlir::edsc::BinaryExpr::getLHS() const { Expr mlir::edsc::BinaryExpr::getLHS() const {
return static_cast<ImplType *>(storage)->operands.front(); return static_cast<ImplType *>(storage)->operands.front();
@ -442,11 +693,15 @@ Expr mlir::edsc::BinaryExpr::getRHS() const {
return static_cast<ImplType *>(storage)->operands.back(); return static_cast<ImplType *>(storage)->operands.back();
} }
mlir::edsc::TernaryExpr::TernaryExpr(ExprKind kind, Expr cond, Expr lhs, mlir::edsc::TernaryExpr::TernaryExpr(StringRef name, Expr cond, Expr lhs,
Expr rhs) Expr rhs)
: Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) { : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
// Initialize with placement new. // Initialize with placement new.
new (storage) detail::ExprStorage(kind, {}, {cond, lhs, rhs}, {}); assert(lhs.getResultTypes().size() == 1 && "expected single-result expr");
assert(rhs.getResultTypes().size() == 1 && "expected single-result expr");
new (storage)
detail::ExprStorage(ExprKind::FIRST_TERNARY_EXPR, name,
{lhs.getResultTypes().front()}, {cond, lhs, rhs}, {});
} }
Expr mlir::edsc::TernaryExpr::getCond() const { Expr mlir::edsc::TernaryExpr::getCond() const {
return static_cast<ImplType *>(storage)->operands[0]; return static_cast<ImplType *>(storage)->operands[0];
@ -458,11 +713,13 @@ Expr mlir::edsc::TernaryExpr::getRHS() const {
return static_cast<ImplType *>(storage)->operands[2]; return static_cast<ImplType *>(storage)->operands[2];
} }
mlir::edsc::VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef<Expr> exprs, mlir::edsc::VariadicExpr::VariadicExpr(StringRef name, ArrayRef<Expr> exprs,
ArrayRef<Type> types) ArrayRef<Type> types,
ArrayRef<NamedAttribute> attrs)
: Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) { : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
// Initialize with placement new. // Initialize with placement new.
new (storage) detail::ExprStorage(kind, types, exprs, {}); new (storage) detail::ExprStorage(ExprKind::FIRST_VARIADIC_EXPR, name, types,
exprs, attrs);
} }
ArrayRef<Expr> mlir::edsc::VariadicExpr::getExprs() const { ArrayRef<Expr> mlir::edsc::VariadicExpr::getExprs() const {
return static_cast<ImplType *>(storage)->operands; return static_cast<ImplType *>(storage)->operands;
@ -476,7 +733,7 @@ mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind,
ArrayRef<Type> types) ArrayRef<Type> types)
: Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) { : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
// Initialize with placement new. // Initialize with placement new.
new (storage) detail::ExprStorage(kind, types, exprs, {}); new (storage) detail::ExprStorage(kind, "", types, exprs, {});
} }
ArrayRef<Expr> mlir::edsc::StmtBlockLikeExpr::getExprs() const { ArrayRef<Expr> mlir::edsc::StmtBlockLikeExpr::getExprs() const {
return static_cast<ImplType *>(storage)->operands; return static_cast<ImplType *>(storage)->operands;
@ -494,8 +751,9 @@ mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
lhs, rhs, ArrayRef<Stmt>(enclosedStmtStorage, enclosedStmts.size())}; lhs, rhs, ArrayRef<Stmt>(enclosedStmtStorage, enclosedStmts.size())};
} }
// Statement with enclosed statements does not have a LHS.
mlir::edsc::Stmt::Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> enclosedStmts) mlir::edsc::Stmt::Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> enclosedStmts)
: Stmt(Bindable(Expr()), rhs, enclosedStmts) {} : Stmt(Bindable(Expr(Type())), rhs, enclosedStmts) {}
edsc_stmt_t makeStmt(edsc_expr_t e) { edsc_stmt_t makeStmt(edsc_expr_t e) {
assert(e && "unexpected empty expression"); assert(e && "unexpected empty expression");
@ -503,7 +761,7 @@ edsc_stmt_t makeStmt(edsc_expr_t e) {
} }
Stmt &mlir::edsc::Stmt::operator=(const Expr &expr) { Stmt &mlir::edsc::Stmt::operator=(const Expr &expr) {
Stmt res(Bindable(Expr()), expr, {}); Stmt res(Bindable(Expr(Type())), expr, {});
std::swap(res.storage, this->storage); std::swap(res.storage, this->storage);
return *this; return *this;
} }
@ -678,6 +936,12 @@ mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
return mlir_type_t{ft.getAsOpaquePointer()}; return mlir_type_t{ft.getAsOpaquePointer()};
} }
mlir_type_t makeIndexType(mlir_context_t context) {
auto *ctx = reinterpret_cast<mlir::MLIRContext *>(context);
auto type = mlir::IndexType::get(ctx);
return mlir_type_t{type.getAsOpaquePointer()};
}
unsigned getFunctionArity(mlir_func_t function) { unsigned getFunctionArity(mlir_func_t function) {
auto *f = reinterpret_cast<mlir::Function *>(function); auto *f = reinterpret_cast<mlir::Function *>(function);
return f->getNumArguments(); return f->getNumArguments();

View File

@ -167,7 +167,7 @@ VectorTransferRewriter<VectorTransferOpTy>::makeVectorTransferAccessInfo() {
// Create new Exprs for ivs, they will be bound at `For` Stmt // Create new Exprs for ivs, they will be bound at `For` Stmt
// construction. // construction.
auto ivs = makeNewExprs(vectorShape.size()); auto ivs = makeNewExprs(vectorShape.size(), this->rewriter->getIndexType());
// Create and bind Exprs to refer to the Value for memref sizes. // Create and bind Exprs to refer to the Value for memref sizes.
auto memRefSizes = emitter.makeBoundMemRefShape(transfer->getMemRef()); auto memRefSizes = emitter.makeBoundMemRefShape(transfer->getMemRef());
@ -222,9 +222,9 @@ VectorTransferRewriter<VectorTransferOpTy>::makeVectorTransferAccessInfo() {
// Create the proper bindables for lbs, ubs and steps. Additionally, if we // Create the proper bindables for lbs, ubs and steps. Additionally, if we
// recorded a coalescing index, permute the loop informations. // recorded a coalescing index, permute the loop informations.
auto lbs = makeNewExprs(ivs.size()); auto lbs = makeNewExprs(ivs.size(), this->rewriter->getIndexType());
auto ubs = copyExprs(vectorSizes); auto ubs = copyExprs(vectorSizes);
auto steps = makeNewExprs(ivs.size()); auto steps = makeNewExprs(ivs.size(), this->rewriter->getIndexType());
if (coalescingIndex >= 0) { if (coalescingIndex >= 0) {
std::swap(ivs[coalescingIndex], ivs.back()); std::swap(ivs[coalescingIndex], ivs.back());
std::swap(lbs[coalescingIndex], lbs.back()); std::swap(lbs[coalescingIndex], lbs.back());
@ -257,11 +257,14 @@ VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
MemRefType::get(vectorShape, vectorType.getElementType(), {}, 0)), MemRefType::get(vectorShape, vectorType.getElementType(), {}, 0)),
vectorMemRefType(MemRefType::get({1}, vectorType, {}, 0)), vectorMemRefType(MemRefType::get({1}, vectorType, {}, 0)),
emitter(edsc::MLIREmitter(rewriter->getBuilder(), transfer->getLoc())), emitter(edsc::MLIREmitter(rewriter->getBuilder(), transfer->getLoc())),
vectorSizes(edsc::makeNewExprs(vectorShape.size())), zero(emitter.zero()), vectorSizes(
one(emitter.one()) { edsc::makeNewExprs(vectorShape.size(), rewriter->getIndexType())),
zero(emitter.zero()), one(emitter.one()),
scalarMemRef(transfer->getMemRefType()) {
// Bind the Bindable. // Bind the Bindable.
SmallVector<Value *, 8> transferIndices(transfer->getIndices()); SmallVector<Value *, 8> transferIndices(transfer->getIndices());
accesses = edsc::makeNewExprs(transferIndices.size()); accesses = edsc::makeNewExprs(transferIndices.size(),
this->rewriter->getIndexType());
emitter.bind(edsc::Bindable(scalarMemRef), transfer->getMemRef()) emitter.bind(edsc::Bindable(scalarMemRef), transfer->getMemRef())
.template bindZipRangeConstants<ConstantIndexOp>( .template bindZipRangeConstants<ConstantIndexOp>(
llvm::zip(vectorSizes, vectorShape)) llvm::zip(vectorSizes, vectorShape))
@ -321,7 +324,11 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
auto &lbs = accessInfo.lowerBoundsExprs; auto &lbs = accessInfo.lowerBoundsExprs;
auto &ubs = accessInfo.upperBoundsExprs; auto &ubs = accessInfo.upperBoundsExprs;
auto &steps = accessInfo.stepExprs; auto &steps = accessInfo.stepExprs;
Expr scalarValue, vectorValue, tmpAlloc, tmpDealloc, vectorView;
auto vectorType = this->transfer->getVectorType();
auto scalarType = this->transfer->getMemRefType().getElementType();
Expr scalarValue(scalarType), vectorValue(vectorType), tmpAlloc(tmpMemRefType), tmpDealloc(Type{}), vectorView(vectorMemRefType);
auto block = edsc::block({ auto block = edsc::block({
tmpAlloc = alloc(tmpMemRefType), tmpAlloc = alloc(tmpMemRefType),
vectorView = vector_type_cast(Expr(tmpAlloc), vectorMemRefType), vectorView = vector_type_cast(Expr(tmpAlloc), vectorMemRefType),
@ -368,7 +375,7 @@ template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
auto accessInfo = makeVectorTransferAccessInfo(); auto accessInfo = makeVectorTransferAccessInfo();
// Bind vector value for the vector_transfer_write. // Bind vector value for the vector_transfer_write.
Expr vectorValue; Expr vectorValue(transfer->getVectorType());
emitter.bind(Bindable(vectorValue), transfer->getVector()); emitter.bind(Bindable(vectorValue), transfer->getVector());
// clang-format off // clang-format off
@ -376,7 +383,8 @@ template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
auto &lbs = accessInfo.lowerBoundsExprs; auto &lbs = accessInfo.lowerBoundsExprs;
auto &ubs = accessInfo.upperBoundsExprs; auto &ubs = accessInfo.upperBoundsExprs;
auto &steps = accessInfo.stepExprs; auto &steps = accessInfo.stepExprs;
Expr scalarValue, tmpAlloc, tmpDealloc, vectorView; auto scalarType = tmpMemRefType.getElementType();
Expr scalarValue(scalarType), tmpAlloc(tmpMemRefType), tmpDealloc(Type{}), vectorView(vectorMemRefType);
auto block = edsc::block({ auto block = edsc::block({
tmpAlloc = alloc(tmpMemRefType), tmpAlloc = alloc(tmpMemRefType),
vectorView = vector_type_cast(tmpAlloc, vectorMemRefType), vectorView = vector_type_cast(tmpAlloc, vectorMemRefType),

View File

@ -11,7 +11,7 @@ def X_AddOp : Op<"x.add">,
// TODO: extract referenceImplementation to Op. // TODO: extract referenceImplementation to Op.
// TODO: shrink the reference implementation // TODO: shrink the reference implementation
code referenceImplementation = [{ code referenceImplementation = [{
auto ivs = makeNewExprs(view_A.rank()); auto ivs = makeNewExprs(view_A.rank(), builder.getIndexType());
// TODO(jpienaar@): automate the positional/named extraction. Need to be a // TODO(jpienaar@): automate the positional/named extraction. Need to be a
// bit careful about things memref (from which a "view" can be extracted) // bit careful about things memref (from which a "view" can be extracted)
// and the rest (see ReferenceImplGen.cpp). // and the rest (see ReferenceImplGen.cpp).