Python bindings: drop MLIREmitter and related functionality

This completes the transition of Python bindings to use the declarative
    builders infrastructure instead of the now-deprecated EDSC emitter
    infrastructure.  The relevant unit tests have been replicated using the new
    functionality and the remaining end-to-end compilation tests have been updated
    accordingly.  The latter show an improvement in brevity and readability.

--

PiperOrigin-RevId: 241713489
This commit is contained in:
Alex Zinenko 2019-04-03 05:43:44 -07:00 committed by Mehdi Amini
parent 509619829d
commit 7a30ac97c8
3 changed files with 22 additions and 934 deletions

View File

@ -27,8 +27,6 @@
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/EDSC/MLIREmitter.h"
#include "mlir/EDSC/Types.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Function.h"
@ -225,15 +223,6 @@ private:
std::unique_ptr<mlir::ExecutionEngine> engine;
};
struct ContextManager {
void enter() { context = new ScopedEDSCContext(); }
void exit(py::object, py::object, py::object) {
delete context;
context = nullptr;
}
mlir::edsc::ScopedEDSCContext *context;
};
struct PythonFunctionContext {
PythonFunctionContext(PythonFunction f) : function(f) {}
PythonFunctionContext(PythonMLIRModule &module, const std::string &name,
@ -268,18 +257,6 @@ PythonFunctionContext PythonMLIRModule::makeFunctionContext(
return PythonFunctionContext(func);
}
struct PythonExpr {
PythonExpr() : expr{nullptr} {}
PythonExpr(const PythonBindable &bindable);
PythonExpr(const edsc_expr_t &expr) : expr{expr} {}
operator edsc_expr_t() { return expr; }
std::string str() {
assert(expr && "unexpected empty expr");
return Expr(*this).str();
}
edsc_expr_t expr;
};
struct PythonBlockHandle {
PythonBlockHandle() : value(nullptr) {}
PythonBlockHandle(const PythonBlockHandle &other) = default;
@ -440,45 +417,6 @@ public:
BlockBuilder *builder = nullptr;
};
struct PythonBindable : public PythonExpr {
explicit PythonBindable(const PythonType &type)
: PythonExpr(edsc_expr_t{makeBindable(type.type)}) {}
PythonBindable(PythonExpr expr) : PythonExpr(expr) {
assert(Expr(expr).isa<Bindable>() && "Expected Bindable");
}
std::string str() {
assert(expr && "unexpected empty expr");
return Expr(expr).str();
}
};
struct PythonStmt {
PythonStmt() : stmt{nullptr} {}
PythonStmt(const edsc_stmt_t &stmt) : stmt{stmt} {}
PythonStmt(const PythonExpr &e) : stmt{makeStmt(e.expr)} {}
operator edsc_stmt_t() { return stmt; }
std::string str() {
assert(stmt && "unexpected empty stmt");
return Stmt(stmt).str();
}
edsc_stmt_t stmt;
};
struct PythonBlock {
PythonBlock() : blk{nullptr} {}
PythonBlock(const edsc_block_t &other) : blk{other} {}
PythonBlock(const PythonBlock &other) = default;
operator edsc_block_t() { return blk; }
std::string str() {
assert(blk && "unexpected empty block");
return StmtBlock(blk).str();
}
PythonBlock set(const py::list &stmts);
edsc_block_t blk;
};
struct PythonAttribute {
PythonAttribute() : attr(nullptr) {}
PythonAttribute(const mlir_attr_t &a) : attr(a) {}
@ -550,12 +488,6 @@ private:
std::unordered_map<std::string, PythonAttribute> attrs;
};
struct PythonIndexed : public edsc_indexed_t {
PythonIndexed(PythonExpr e) : edsc_indexed_t{makeIndexed(e)} {}
PythonIndexed(PythonBindable b) : edsc_indexed_t{makeIndexed(b)} {}
operator PythonExpr() { return PythonExpr(base); }
};
struct PythonIndexedValue {
explicit PythonIndexedValue(PythonType type)
: indexed(Type::getFromOpaquePointer(type.type)) {}
@ -584,57 +516,6 @@ struct PythonIndexedValue {
IndexedValue indexed;
};
struct PythonMaxExpr {
PythonMaxExpr() : expr(nullptr) {}
PythonMaxExpr(const edsc_max_expr_t &e) : expr(e) {}
operator edsc_max_expr_t() { return expr; }
edsc_max_expr_t expr;
};
struct PythonMinExpr {
PythonMinExpr() : expr(nullptr) {}
PythonMinExpr(const edsc_min_expr_t &e) : expr(e) {}
operator edsc_min_expr_t() { return expr; }
edsc_min_expr_t expr;
};
struct MLIRFunctionEmitter {
MLIRFunctionEmitter(PythonFunction f)
: currentFunction(reinterpret_cast<mlir::Function *>(f.function)),
currentBuilder(currentFunction),
emitter(&currentBuilder, currentFunction->getLoc()) {}
PythonExpr bindConstantBF16(double value);
PythonExpr bindConstantF16(float value);
PythonExpr bindConstantF32(float value);
PythonExpr bindConstantF64(double value);
PythonExpr bindConstantInt(int64_t value, unsigned bitwidth);
PythonExpr bindConstantIndex(int64_t value);
PythonExpr bindConstantFunction(PythonFunction func);
PythonExpr bindFunctionArgument(unsigned pos);
py::list bindFunctionArguments();
py::list bindFunctionArgumentView(unsigned pos);
py::list bindMemRefShape(PythonExpr boundMemRef);
py::list bindIndexedMemRefShape(PythonIndexed boundMemRef) {
return bindMemRefShape(boundMemRef.base);
}
py::list bindMemRefView(PythonExpr boundMemRef);
py::list bindIndexedMemRefView(PythonIndexed boundMemRef) {
return bindMemRefView(boundMemRef.base);
}
void emit(PythonStmt stmt);
void emitBlock(PythonBlock block);
void emitBlockBody(PythonBlock block);
private:
mlir::Function *currentFunction;
mlir::FuncBuilder currentBuilder;
mlir::edsc::MLIREmitter emitter;
edsc_mlir_emitter_t c_emitter;
};
template <typename ListTy, typename PythonTy, typename Ty>
ListTy makeCList(SmallVectorImpl<Ty> &owning, const py::list &list) {
for (auto &inp : list) {
@ -643,120 +524,11 @@ ListTy makeCList(SmallVectorImpl<Ty> &owning, const py::list &list) {
return ListTy{owning.data(), owning.size()};
}
static edsc_stmt_list_t makeCStmts(llvm::SmallVectorImpl<edsc_stmt_t> &owning,
const py::list &stmts) {
return makeCList<edsc_stmt_list_t, PythonStmt>(owning, stmts);
}
static edsc_expr_list_t makeCExprs(llvm::SmallVectorImpl<edsc_expr_t> &owning,
const py::list &exprs) {
return makeCList<edsc_expr_list_t, PythonExpr>(owning, exprs);
}
static mlir_type_list_t makeCTypes(llvm::SmallVectorImpl<mlir_type_t> &owning,
const py::list &types) {
return makeCList<mlir_type_list_t, PythonType>(owning, types);
}
static edsc_block_list_t
makeCBlocks(llvm::SmallVectorImpl<edsc_block_t> &owning,
const py::list &blocks) {
return makeCList<edsc_block_list_t, PythonBlock>(owning, blocks);
}
PythonExpr::PythonExpr(const PythonBindable &bindable) : expr{bindable.expr} {}
PythonExpr MLIRFunctionEmitter::bindConstantBF16(double value) {
return ::bindConstantBF16(edsc_mlir_emitter_t{&emitter}, value);
}
PythonExpr MLIRFunctionEmitter::bindConstantF16(float value) {
return ::bindConstantF16(edsc_mlir_emitter_t{&emitter}, value);
}
PythonExpr MLIRFunctionEmitter::bindConstantF32(float value) {
return ::bindConstantF32(edsc_mlir_emitter_t{&emitter}, value);
}
PythonExpr MLIRFunctionEmitter::bindConstantF64(double value) {
return ::bindConstantF64(edsc_mlir_emitter_t{&emitter}, value);
}
PythonExpr MLIRFunctionEmitter::bindConstantInt(int64_t value,
unsigned bitwidth) {
return ::bindConstantInt(edsc_mlir_emitter_t{&emitter}, value, bitwidth);
}
PythonExpr MLIRFunctionEmitter::bindConstantIndex(int64_t value) {
return ::bindConstantIndex(edsc_mlir_emitter_t{&emitter}, value);
}
PythonExpr MLIRFunctionEmitter::bindConstantFunction(PythonFunction func) {
return ::bindConstantFunction(edsc_mlir_emitter_t{&emitter}, func);
}
PythonExpr MLIRFunctionEmitter::bindFunctionArgument(unsigned pos) {
return ::bindFunctionArgument(edsc_mlir_emitter_t{&emitter},
mlir_func_t{currentFunction}, pos);
}
PythonExpr getPythonType(edsc_expr_t e) { return PythonExpr(e); }
template <typename T> py::list makePyList(llvm::ArrayRef<T> owningResults) {
py::list res;
for (auto e : owningResults) {
res.append(getPythonType(e));
}
return res;
}
py::list MLIRFunctionEmitter::bindFunctionArguments() {
auto arity = getFunctionArity(mlir_func_t{currentFunction});
llvm::SmallVector<edsc_expr_t, 8> owningResults(arity);
edsc_expr_list_t results{owningResults.data(), owningResults.size()};
::bindFunctionArguments(edsc_mlir_emitter_t{&emitter},
mlir_func_t{currentFunction}, &results);
return makePyList(ArrayRef<edsc_expr_t>{owningResults});
}
py::list MLIRFunctionEmitter::bindMemRefShape(PythonExpr boundMemRef) {
auto rank = getBoundMemRefRank(edsc_mlir_emitter_t{&emitter}, boundMemRef);
llvm::SmallVector<edsc_expr_t, 8> owningShapes(rank);
edsc_expr_list_t resultShapes{owningShapes.data(), owningShapes.size()};
::bindMemRefShape(edsc_mlir_emitter_t{&emitter}, boundMemRef, &resultShapes);
return makePyList(ArrayRef<edsc_expr_t>{owningShapes});
}
py::list MLIRFunctionEmitter::bindMemRefView(PythonExpr boundMemRef) {
auto rank = getBoundMemRefRank(edsc_mlir_emitter_t{&emitter}, boundMemRef);
// Own the PythonExpr for the arg as well as all its dims.
llvm::SmallVector<edsc_expr_t, 8> owningLbs(rank);
llvm::SmallVector<edsc_expr_t, 8> owningUbs(rank);
llvm::SmallVector<edsc_expr_t, 8> owningSteps(rank);
edsc_expr_list_t resultLbs{owningLbs.data(), owningLbs.size()};
edsc_expr_list_t resultUbs{owningUbs.data(), owningUbs.size()};
edsc_expr_list_t resultSteps{owningSteps.data(), owningSteps.size()};
::bindMemRefView(edsc_mlir_emitter_t{&emitter}, boundMemRef, &resultLbs,
&resultUbs, &resultSteps);
py::list res;
res.append(makePyList(ArrayRef<edsc_expr_t>{owningLbs}));
res.append(makePyList(ArrayRef<edsc_expr_t>{owningUbs}));
res.append(makePyList(ArrayRef<edsc_expr_t>{owningSteps}));
return res;
}
void MLIRFunctionEmitter::emit(PythonStmt stmt) {
emitter.emitStmt(Stmt(stmt));
}
void MLIRFunctionEmitter::emitBlock(PythonBlock block) {
emitter.emitBlock(StmtBlock(block));
}
void MLIRFunctionEmitter::emitBlockBody(PythonBlock block) {
emitter.emitStmts(StmtBlock(block).getBody());
}
PythonFunction
PythonMLIRModule::declareFunction(const std::string &name,
const py::list &inputs,
@ -811,29 +583,6 @@ PythonMLIRModule::declareFunction(const std::string &name,
return func;
}
PythonExpr PythonMLIRModule::op(const std::string &name, PythonType type,
const py::list &arguments,
const py::list &successors,
py::kwargs attributes) {
SmallVector<edsc_expr_t, 8> owningExprs;
SmallVector<edsc_block_t, 4> owningBlocks;
SmallVector<mlir_named_attr_t, 4> owningAttrs;
SmallVector<std::string, 4> owningAttrNames;
owningAttrs.reserve(attributes.size());
owningAttrNames.reserve(attributes.size());
for (const auto &kvp : attributes) {
owningAttrNames.push_back(kvp.first.str());
auto value = kvp.second.cast<PythonAttribute>();
owningAttrs.push_back({owningAttrNames.back().c_str(), value});
}
return PythonExpr(::Op(mlir_context_t(&mlirContext), name.c_str(), type,
makeCExprs(owningExprs, arguments),
makeCBlocks(owningBlocks, successors),
{owningAttrs.data(), owningAttrs.size()}));
}
PythonAttributedType PythonType::attachAttributeDict(
const std::unordered_map<std::string, PythonAttribute> &attrs) const {
return PythonAttributedType(*this, attrs);
@ -847,88 +596,10 @@ PythonAttribute PythonMLIRModule::boolAttr(bool value) {
return PythonAttribute(::makeBoolAttr(&mlirContext, value));
}
PythonBlock PythonBlock::set(const py::list &stmts) {
SmallVector<edsc_stmt_t, 8> owning;
::BlockSetBody(blk, makeCStmts(owning, stmts));
return *this;
}
PythonExpr dispatchCall(py::args args, py::kwargs kwargs) {
assert(args.size() != 0);
llvm::SmallVector<edsc_expr_t, 8> exprs;
exprs.reserve(args.size());
for (auto arg : args) {
exprs.push_back(arg.cast<PythonExpr>());
}
edsc_expr_list_t operands{exprs.data() + 1, exprs.size() - 1};
if (kwargs && kwargs.contains("result")) {
for (const auto &kvp : kwargs) {
if (static_cast<std::string>(kvp.first.str()) == "result")
return ::Call1(exprs.front(), kvp.second.cast<PythonType>(), operands);
}
}
return ::Call0(exprs.front(), operands);
}
PYBIND11_MODULE(pybind, m) {
m.doc() =
"Python bindings for MLIR Embedded Domain-Specific Components (EDSCs)";
m.def("version", []() { return "EDSC Python extensions v0.0"; });
m.def("initContext",
[]() { return static_cast<void *>(new ScopedEDSCContext()); });
m.def("deleteContext",
[](void *ctx) { delete reinterpret_cast<ScopedEDSCContext *>(ctx); });
m.def("Block", [](const py::list &args, const py::list &stmts) {
SmallVector<edsc_stmt_t, 8> owning;
SmallVector<edsc_expr_t, 8> owningArgs;
return PythonBlock(
::Block(makeCExprs(owningArgs, args), makeCStmts(owning, stmts)));
});
m.def("Block", [](const py::list &stmts) {
SmallVector<edsc_stmt_t, 8> owning;
edsc_expr_list_t args{nullptr, 0};
return PythonBlock(::Block(args, makeCStmts(owning, stmts)));
});
m.def(
"Branch",
[](PythonBlock destination, const py::list &operands) {
SmallVector<edsc_expr_t, 8> owning;
return PythonStmt(::Branch(destination, makeCExprs(owning, operands)));
},
py::arg("destination"), py::arg("operands") = py::list());
m.def("CondBranch",
[](PythonExpr condition, PythonBlock trueDestination,
const py::list &trueOperands, PythonBlock falseDestination,
const py::list &falseOperands) {
SmallVector<edsc_expr_t, 8> owningTrue;
SmallVector<edsc_expr_t, 8> owningFalse;
return PythonStmt(::CondBranch(
condition, trueDestination, makeCExprs(owningTrue, trueOperands),
falseDestination, makeCExprs(owningFalse, falseOperands)));
});
m.def("CondBranch", [](PythonExpr condition, PythonBlock trueDestination,
PythonBlock falseDestination) {
edsc_expr_list_t emptyList;
emptyList.exprs = nullptr;
emptyList.n = 0;
return PythonStmt(::CondBranch(condition, trueDestination, emptyList,
falseDestination, emptyList));
});
m.def("For", [](const py::list &ivs, const py::list &lbs, const py::list &ubs,
const py::list &steps, const py::list &stmts) {
SmallVector<edsc_expr_t, 8> owningIVs;
SmallVector<edsc_expr_t, 8> owningLBs;
SmallVector<edsc_expr_t, 8> owningUBs;
SmallVector<edsc_expr_t, 8> owningSteps;
SmallVector<edsc_stmt_t, 8> owningStmts;
return PythonStmt(
::ForNest(makeCExprs(owningIVs, ivs), makeCExprs(owningLBs, lbs),
makeCExprs(owningUBs, ubs), makeCExprs(owningSteps, steps),
makeCStmts(owningStmts, stmts)));
});
m.def("version", []() { return "EDSC Python extensions v1.0"; });
py::class_<PythonLoopContext>(
m, "LoopContext", "A context for building the body of a 'for' loop")
@ -1032,68 +703,6 @@ PYBIND11_MODULE(pybind, m) {
return ValueHandle::create(name, operandHandles, types, attrs);
});
m.def("Max", [](const py::list &args) {
SmallVector<edsc_expr_t, 8> owning;
return PythonMaxExpr(::Max(makeCExprs(owning, args)));
});
m.def("Min", [](const py::list &args) {
SmallVector<edsc_expr_t, 8> owning;
return PythonMinExpr(::Min(makeCExprs(owning, args)));
});
m.def("For", [](PythonExpr iv, PythonExpr lb, PythonExpr ub, PythonExpr step,
const py::list &stmts) {
SmallVector<edsc_stmt_t, 8> owning;
return PythonStmt(::For(iv, lb, ub, step, makeCStmts(owning, stmts)));
});
m.def("For", [](PythonExpr iv, PythonMaxExpr lb, PythonMinExpr ub,
PythonExpr step, const py::list &stmts) {
SmallVector<edsc_stmt_t, 8> owning;
return PythonStmt(::MaxMinFor(iv, lb, ub, step, makeCStmts(owning, stmts)));
});
m.def("Select", [](PythonExpr cond, PythonExpr e1, PythonExpr e2) {
return PythonExpr(::Select(cond, e1, e2));
});
m.def("Return", []() {
return PythonStmt(::Return(edsc_expr_list_t{nullptr, 0}));
});
m.def("Return", [](const py::list &returns) {
SmallVector<edsc_expr_t, 8> owningExprs;
return PythonStmt(::Return(makeCExprs(owningExprs, returns)));
});
m.def("ConstantInteger", [](PythonType type, int64_t value) {
return PythonExpr(::ConstantInteger(type, value));
});
#define DEFINE_PYBIND_BINARY_OP(PYTHON_NAME, C_NAME) \
m.def(PYTHON_NAME, [](PythonExpr e1, PythonExpr e2) { \
return PythonExpr(::C_NAME(e1, e2)); \
});
DEFINE_PYBIND_BINARY_OP("Add", Add);
DEFINE_PYBIND_BINARY_OP("Mul", Mul);
DEFINE_PYBIND_BINARY_OP("Sub", Sub);
DEFINE_PYBIND_BINARY_OP("Div", Div);
DEFINE_PYBIND_BINARY_OP("Rem", Rem);
DEFINE_PYBIND_BINARY_OP("FloorDiv", FloorDiv);
DEFINE_PYBIND_BINARY_OP("CeilDiv", CeilDiv);
DEFINE_PYBIND_BINARY_OP("LT", LT);
DEFINE_PYBIND_BINARY_OP("LE", LE);
DEFINE_PYBIND_BINARY_OP("GT", GT);
DEFINE_PYBIND_BINARY_OP("GE", GE);
DEFINE_PYBIND_BINARY_OP("EQ", EQ);
DEFINE_PYBIND_BINARY_OP("NE", NE);
DEFINE_PYBIND_BINARY_OP("And", And);
DEFINE_PYBIND_BINARY_OP("Or", Or);
#undef DEFINE_PYBIND_BINARY_OP
#define DEFINE_PYBIND_UNARY_OP(PYTHON_NAME, C_NAME) \
m.def(PYTHON_NAME, [](PythonExpr e1) { return PythonExpr(::C_NAME(e1)); });
DEFINE_PYBIND_UNARY_OP("Negate", Negate);
#undef DEFINE_PYBIND_UNARY_OP
py::class_<PythonFunction>(m, "Function",
"Wrapping class for mlir::Function.")
.def(py::init<PythonFunction>())
@ -1104,12 +713,6 @@ PYBIND11_MODULE(pybind, m) {
.def("arg", &PythonFunction::arg,
"Get the ValueHandle to the indexed argument of the function");
py::class_<PythonBlock>(m, "StmtBlock",
"Wrapping class for mlir::edsc::StmtBlock")
.def(py::init<PythonBlock>())
.def("set", &PythonBlock::set)
.def("__str__", &PythonBlock::str);
py::class_<PythonAttribute>(m, "Attribute",
"Wrapping class for mlir::Attribute")
.def(py::init<PythonAttribute>())
@ -1143,9 +746,6 @@ PYBIND11_MODULE(pybind, m) {
"directly require integration with a tensor library (e.g. numpy). This "
"is left as the prerogative of libraries and frameworks for now.")
.def(py::init<>())
.def("op", &PythonMLIRModule::op, py::arg("name"), py::arg("type"),
py::arg("arguments"), py::arg("successors") = py::list(),
"Creates a new expression identified by its canonical name.")
.def("boolAttr", &PythonMLIRModule::boolAttr,
"Creates an mlir::BoolAttr with the given value")
.def(
@ -1198,15 +798,6 @@ PYBIND11_MODULE(pybind, m) {
.def("__str__", &PythonMLIRModule::getIR,
"Get the string representation of the module");
py::class_<ContextManager>(
m, "ContextManager",
"An EDSC context manager is the memory arena containing all the EDSC "
"allocations.\nUsage:\n\n"
"with E.ContextManager() as _:\n i = E.Expr(E.Bindable())\n ...")
.def(py::init<>())
.def("__enter__", &ContextManager::enter)
.def("__exit__", &ContextManager::exit);
py::class_<PythonFunctionContext>(
m, "FunctionContext", "A wrapper around mlir::edsc::ScopedContext")
.def(py::init<PythonFunction>())
@ -1313,131 +904,6 @@ PYBIND11_MODULE(pybind, m) {
.def(py::init<PythonValueHandle>())
.def("load", &PythonIndexedValue::load)
.def("store", &PythonIndexedValue::store);
py::class_<MLIRFunctionEmitter>(
m, "MLIRFunctionEmitter",
"An MLIRFunctionEmitter is used to fill an empty function body. This is "
"a staged process:\n"
" 1. create or retrieve an mlir::Function `f` with an empty body;\n"
" 2. make an `MLIRFunctionEmitter(f)` to build the current function;\n"
" 3. create leaf Expr that are either Bindable or already Expr that are"
" bound to constants and function arguments by using methods of "
" `MLIRFunctionEmitter`;\n"
" 4. build the function body using Expr, Indexed and Stmt;\n"
" 5. emit the MLIR to implement the function body.")
.def(py::init<PythonFunction>())
.def("bind_constant_bf16", &MLIRFunctionEmitter::bindConstantBF16)
.def("bind_constant_f16", &MLIRFunctionEmitter::bindConstantF16)
.def("bind_constant_f32", &MLIRFunctionEmitter::bindConstantF32)
.def("bind_constant_f64", &MLIRFunctionEmitter::bindConstantF64)
.def("bind_constant_int", &MLIRFunctionEmitter::bindConstantInt)
.def("bind_constant_index", &MLIRFunctionEmitter::bindConstantIndex)
.def("bind_constant_function", &MLIRFunctionEmitter::bindConstantFunction)
.def("bind_function_argument", &MLIRFunctionEmitter::bindFunctionArgument,
"Returns an Expr that has been bound to a positional argument in "
"the current Function.")
.def("bind_function_arguments",
&MLIRFunctionEmitter::bindFunctionArguments,
"Returns a list of Expr where each Expr has been bound to the "
"corresponding positional argument in the current Function.")
.def("bind_memref_shape", &MLIRFunctionEmitter::bindMemRefShape,
"Returns a list of Expr where each Expr has been bound to the "
"corresponding dimension of the memref.")
.def("bind_memref_view", &MLIRFunctionEmitter::bindMemRefView,
"Returns three lists (lower bound, upper bound and step) of Expr "
"where each triplet of Expr has been bound to the minimal offset, "
"extent and stride of the corresponding dimension of the memref.")
.def("bind_indexed_shape", &MLIRFunctionEmitter::bindIndexedMemRefShape,
"Same as bind_memref_shape but returns a list of `Indexed` that "
"support load and store operations")
.def("bind_indexed_view", &MLIRFunctionEmitter::bindIndexedMemRefView,
"Same as bind_memref_view but returns lists of `Indexed` that "
"support load and store operations")
.def("emit", &MLIRFunctionEmitter::emit,
"Emits the MLIR for the EDSC expressions and statements in the "
"current function body.")
.def("emit", &MLIRFunctionEmitter::emitBlock,
"Emits the MLIR for the EDSC statements into a new block")
.def("emit_inplace", &MLIRFunctionEmitter::emitBlockBody,
"Emits the MLIR for the EDSC statements contained in a EDSC block "
"into the current function body without creating a new block");
py::class_<PythonExpr>(m, "Expr", "Wrapping class for mlir::edsc::Expr")
.def(py::init<PythonBindable>())
.def("__add__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::Add(e1, e2)); })
.def("__sub__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::Sub(e1, e2)); })
.def("__mul__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::Mul(e1, e2)); })
.def("__div__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::Div(e1, e2)); })
.def("__truediv__",
[](PythonExpr e1, PythonExpr e2) {
return PythonExpr(::Div(e1, e2));
})
.def("__mod__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::Rem(e1, e2)); })
.def("__lt__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::LT(e1, e2)); })
.def("__le__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::LE(e1, e2)); })
.def("__gt__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::GT(e1, e2)); })
.def("__ge__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::GE(e1, e2)); })
.def("__eq__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::EQ(e1, e2)); })
.def("__ne__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::NE(e1, e2)); })
.def("__and__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::And(e1, e2)); })
.def("__or__", [](PythonExpr e1,
PythonExpr e2) { return PythonExpr(::Or(e1, e2)); })
.def("__invert__", [](PythonExpr e) { return PythonExpr(::Negate(e)); })
.def("__call__", &dispatchCall)
.def("__str__", &PythonExpr::str,
R"DOC(Returns the string value for the Expr)DOC");
py::class_<PythonBindable>(
m, "Bindable",
"Wrapping class for mlir::edsc::Bindable.\nA Bindable is a special Expr "
"that can be bound manually to specific MLIR SSA Values.")
.def(py::init<PythonType>())
.def("__str__", &PythonBindable::str);
py::class_<PythonStmt>(m, "Stmt", "Wrapping class for mlir::edsc::Stmt.")
.def(py::init<PythonExpr>())
.def("__str__", &PythonStmt::str,
R"DOC(Returns the string value for the Expr)DOC");
py::class_<PythonIndexed>(
m, "Indexed",
"Wrapping class for mlir::edsc::Indexed.\nAn Indexed is a wrapper class "
"that support load and store operations.")
.def(py::init<PythonExpr>(), R"DOC(Build from existing Expr)DOC")
.def(py::init<PythonBindable>(), R"DOC(Build from existing Bindable)DOC")
.def(
"load",
[](PythonIndexed &instance, const py::list &indices) {
SmallVector<edsc_expr_t, 8> owning;
return PythonExpr(Load(instance, makeCExprs(owning, indices)));
},
R"DOC(Returns an Expr that loads from an Indexed)DOC")
.def(
"store",
[](PythonIndexed &instance, const py::list &indices,
PythonExpr value) {
SmallVector<edsc_expr_t, 8> owning;
return PythonStmt(
Store(value, instance, makeCExprs(owning, indices)));
},
R"DOC(Returns the Stmt that stores into an Indexed)DOC");
py::class_<PythonMaxExpr>(m, "MaxExpr",
"Wrapping class for mlir::edsc::MaxExpr");
py::class_<PythonMinExpr>(m, "MinExpr",
"Wrapping class for mlir::edsc::MinExpr");
}
} // namespace python

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python2 and 3 test for the MLIR EDSC C API and Python bindings"""
"""Python2 and 3 test for the MLIR EDSC Python bindings"""
from __future__ import absolute_import
from __future__ import division
@ -400,212 +400,6 @@ class EdscTest(unittest.TestCase):
self.assertIn("%2 = mulf %0, %1 : f32", code)
self.assertIn("store %2, %arg2[%i0, %i1] : memref<32x32xf32>", code)
def testBindables(self):
with E.ContextManager():
i = E.Expr(E.Bindable(self.i32Type))
self.assertIn("$1", i.__str__())
def testOneExpr(self):
with E.ContextManager():
i, lb, ub = list(
map(E.Expr, [E.Bindable(self.i32Type) for _ in range(3)]))
expr = E.Mul(i, E.Add(lb, ub))
str = expr.__str__()
self.assertIn("($1 * ($2 + $3))", str)
def testCustomOp(self):
with E.ContextManager():
a, b = (E.Expr(E.Bindable(self.i32Type)) for _ in range(2))
c1 = self.module.op(
"std.constant",
self.i32Type, [],
value=self.module.integerAttr(self.i32Type, 42))
expr = self.module.op("std.addi", self.i32Type, [c1, b])
str = expr.__str__()
self.assertIn("addi(42, $2)", str)
def testOneLoop(self):
with E.ContextManager():
i, lb, ub, step = list(
map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
loop = E.For(i, lb, ub, step, [E.Stmt(E.Add(lb, ub))])
str = loop.__str__()
self.assertIn("for($1 = $2 to $3 step $4) {", str)
self.assertIn(" = ($2 + $3)", str)
def testTwoLoops(self):
with E.ContextManager():
i, lb, ub, step = list(
map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
loop = E.For(i, lb, ub, step, [E.For(i, lb, ub, step, [E.Stmt(i)])])
str = loop.__str__()
self.assertIn("for($1 = $2 to $3 step $4) {", str)
self.assertIn("for($1 = $2 to $3 step $4) {", str)
self.assertIn("$5 = $1;", str)
def testNestedLoops(self):
with E.ContextManager():
i, lb, ub = list(
map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
step = E.ConstantInteger(self.indexType, 42)
ivs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
steps = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
loop = E.For(ivs, lbs, ubs, steps, [
E.For(i, lb, ub, step, [E.Stmt(ub * step - lb)]),
])
str = loop.__str__()
self.assertIn("for($5 = $9 to $13 step $17) {", str)
self.assertIn("for($6 = $10 to $14 step $18) {", str)
self.assertIn("for($7 = $11 to $15 step $19) {", str)
self.assertIn("for($8 = $12 to $16 step $20) {", str)
self.assertIn("for($1 = $2 to $3 step 42) {", str)
self.assertIn("= (($3 * 42) + $2 * -1);", str)
def testMaxMinLoop(self):
with E.ContextManager():
i = E.Expr(E.Bindable(self.indexType))
step = E.Expr(E.Bindable(self.indexType))
lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
loop = E.For(i, E.Max(lbs), E.Min(ubs), step, [])
s = str(loop)
self.assertIn("for($1 = max($3, $4, $5, $6) to min($7, $8, $9) step $2)",
s)
def testIndexed(self):
with E.ContextManager():
i, j, k = list(
map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
memrefType = self.module.make_memref_type(self.f32Type, [42, 42])
A, B, C = list(map(E.Indexed, [E.Bindable(memrefType) for _ in range(3)]))
stmt = C.store([i, j], A.load([i, k]) * B.load([k, j]))
str = stmt.__str__()
self.assertIn(" = std.store(", str)
def testMatmul(self):
with E.ContextManager():
ivs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
steps = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
i, j, k = ivs[0], ivs[1], ivs[2]
memrefType = self.module.make_memref_type(self.f32Type, [42, 42])
A, B, C = list(map(E.Indexed, [E.Bindable(memrefType) for _ in range(3)]))
loop = E.For(
ivs, lbs, ubs, steps,
[C.store([i, j],
C.load([i, j]) + A.load([i, k]) * B.load([k, j]))])
str = loop.__str__()
self.assertIn("for($1 = $4 to $7 step $10) {", str)
self.assertIn("for($2 = $5 to $8 step $11) {", str)
self.assertIn("for($3 = $6 to $9 step $12) {", str)
self.assertIn(" = std.store", str)
def testArithmetic(self):
with E.ContextManager():
i, j, k, l = list(
map(E.Expr, [E.Bindable(self.f32Type) for _ in range(4)]))
stmt = i % j + j * k - l / k
str = stmt.__str__()
self.assertIn("((($1 % $2) + ($2 * $3)) - ($4 / $3))", str)
def testBoolean(self):
with E.ContextManager():
i, j, k, l = list(
map(E.Expr, [E.Bindable(self.i32Type) for _ in range(4)]))
stmt1 = (i < j) & (j >= k)
stmt2 = ~(stmt1 | (k == l))
str = stmt2.__str__()
# Note that "a | b" is currently implemented as ~(~a && ~b) and "~a" is
# currently implemented as "constant 1 - a", which leads to this
# expression.
self.assertIn(
"(1 - (1 - ((1 - (($1 < $2) && ($2 >= $3))) && (1 - ($3 == $4)))))",
str)
def testSelect(self):
with E.ContextManager():
i, j, k = list(map(E.Expr, [E.Bindable(self.i32Type) for _ in range(3)]))
stmt = E.Select(i > j, i, j)
str = stmt.__str__()
self.assertIn("select(($1 > $2), $1, $2)", str)
def testCall(self):
with E.ContextManager():
module = E.MLIRModule()
f32 = module.make_scalar_type("f32")
func, arg = [E.Expr(E.Bindable(f32)) for _ in range(2)]
code = func(arg, result=f32)
self.assertIn("@$1($2)", str(code))
def testBlock(self):
with E.ContextManager():
i, j = list(map(E.Expr, [E.Bindable(self.f32Type) for _ in range(2)]))
stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)])
str = stmt.__str__()
self.assertIn("^bb", str)
self.assertIn(" = ($1 + $2)", str)
self.assertIn(" = ($1 - $2)", str)
def testBlockArgs(self):
with E.ContextManager():
module = E.MLIRModule()
t = module.make_scalar_type("i", 32)
i, j = list(map(E.Expr, [E.Bindable(t) for _ in range(2)]))
stmt = E.Block([i, j], [E.Stmt(i + j)])
str = stmt.__str__()
self.assertIn("^bb", str)
self.assertIn("($1, $2):", str)
self.assertIn("($1 + $2)", str)
def testBranch(self):
with E.ContextManager():
i, j = list(map(E.Expr, [E.Bindable(self.i32Type) for _ in range(2)]))
b1 = E.Block([E.Stmt(i + j)])
b2 = E.Block([E.Branch(b1)])
str1 = b1.__str__()
str2 = b2.__str__()
self.assertIn("^bb1:\n" + "$4 = ($1 + $2)", str1)
self.assertIn("^bb2:\n" + "$6 = br ^bb1", str2)
def testBranchArgs(self):
with E.ContextManager():
b1arg, b2arg = (E.Expr(E.Bindable(self.i32Type)) for _ in range(2))
# Declare empty blocks with arguments and bind those arguments.
b1 = E.Block([b1arg], [])
b2 = E.Block([b2arg], [])
one = E.ConstantInteger(self.i32Type, 1)
# Make blocks branch to each other in a sort of infinite loop.
# This checks that the EDSC implementation does not fall into such loop.
b1.set([E.Branch(b2, [b1arg + one])])
b2.set([E.Branch(b1, [b2arg])])
str1 = b1.__str__()
str2 = b2.__str__()
self.assertIn("^bb1($1):\n" + "$6 = br ^bb2(($1 + 1))", str1)
self.assertIn("^bb2($2):\n" + "$8 = br ^bb1($2)", str2)
def testCondBranch(self):
with E.ContextManager():
cond = E.Expr(E.Bindable(self.boolType))
b1 = E.Block([])
b2 = E.Block([])
b3 = E.Block([E.CondBranch(cond, b1, b2)])
str = b3.__str__()
self.assertIn("cond_br($1, ^bb1, ^bb2)", str)
def testCondBranchArgs(self):
with E.ContextManager():
arg1, arg2, arg3 = (E.Expr(E.Bindable(self.i32Type)) for _ in range(3))
expr1, expr2, expr3 = (E.Expr(E.Bindable(self.i32Type)) for _ in range(3))
cond = E.Expr(E.Bindable(self.boolType))
b1 = E.Block([arg1], [])
b2 = E.Block([arg2, arg3], [])
b3 = E.Block([E.CondBranch(cond, b1, [expr1], b2, [expr2, expr3])])
str = b3.__str__()
self.assertIn("cond_br($7, ^bb1($4), ^bb2($5, $6))", str)
def testMLIRScalarTypes(self):
module = E.MLIRModule()
t = module.make_scalar_type("bf16")
@ -641,50 +435,6 @@ class EdscTest(unittest.TestCase):
f = module.make_function("sqrtf", [t], [t])
self.assertIn("func @sqrtf(%arg0: f32) -> f32", f.__str__())
def testMLIRConstantEmission(self):
module = E.MLIRModule()
f = module.make_function("constants", [], [])
with E.ContextManager():
emitter = E.MLIRFunctionEmitter(f)
emitter.bind_constant_bf16(1.23)
emitter.bind_constant_f16(1.23)
emitter.bind_constant_f32(1.23)
emitter.bind_constant_f64(1.23)
emitter.bind_constant_int(1, 1)
emitter.bind_constant_int(123, 8)
emitter.bind_constant_int(123, 16)
emitter.bind_constant_int(123, 32)
emitter.bind_constant_index(123)
emitter.bind_constant_function(f)
str = f.__str__()
self.assertIn("constant 1.230000e+00 : bf16", str)
self.assertIn("constant 1.230470e+00 : f16", str)
self.assertIn("constant 1.230000e+00 : f32", str)
self.assertIn("constant 1.230000e+00 : f64", str)
self.assertIn("constant 1 : i1", str)
self.assertIn("constant 123 : i8", str)
self.assertIn("constant 123 : i16", str)
self.assertIn("constant 123 : i32", str)
self.assertIn("constant 123 : index", str)
self.assertIn("constant @constants : () -> ()", str)
def testMLIRBuiltinEmission(self):
module = E.MLIRModule()
m = module.make_memref_type(self.f32Type, [10]) # f32 tensor
f = module.make_function("call_builtin", [m, m], [])
with E.ContextManager():
emitter = E.MLIRFunctionEmitter(f)
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
fn = module.declare_function("sqrtf", [self.f32Type], [self.f32Type])
fn = emitter.bind_constant_function(fn)
zero = emitter.bind_constant_index(0)
emitter.emit_inplace(E.Block([
output.store([zero], fn(input.load([zero]), result=self.f32Type))
]))
str = f.__str__()
self.assertIn("%f = constant @sqrtf : (f32) -> f32", str)
self.assertIn("call_indirect %f(%0) : (f32) -> f32", str)
def testFunctionDeclaration(self):
module = E.MLIRModule()
boolAttr = self.module.boolAttr(True)
@ -697,99 +447,35 @@ class EdscTest(unittest.TestCase):
"func @foo(memref<10xf32>, memref<10xf32> {llvm.noalias: true}, memref<10xf32> {readonly: true})",
str)
def testMLIRBooleanEmission(self):
def testMLIRBooleanCompilation(self):
m = self.module.make_memref_type(self.boolType, [10]) # i1 tensor
f = self.module.make_function("mkbooltensor", [m, m], [])
with E.ContextManager():
emitter = E.MLIRFunctionEmitter(f)
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
i = E.Expr(E.Bindable(self.indexType))
j = E.Expr(E.Bindable(self.indexType))
k = E.Expr(E.Bindable(self.indexType))
idxs = [i, j, k]
zero = emitter.bind_constant_index(0)
one = emitter.bind_constant_index(1)
ten = emitter.bind_constant_index(10)
b1 = E.And(i < j, j < k)
b2 = E.Negate(b1)
b3 = E.Or(b2, k < j)
loop = E.Block([
E.For(idxs, [zero]*3, [ten]*3, [one]*3, [
output.store([i], E.And(input.load([i]), b3))
]),
E.Return()
])
emitter.emit_inplace(loop)
# str = f.__str__()
# print(str)
with self.module.function_context("mkbooltensor", [m, m], []) as f:
input = E.IndexedValue(f.arg(0))
output = E.IndexedValue(f.arg(1))
zero = E.constant_index(0)
ten = E.constant_index(10)
with E.LoopNestContext([zero] * 3, [ten] * 3, [1] * 3) as (i, j, k):
b1 = (i < j) & (j < k)
b2 = ~b1
b3 = b2 | (k < j)
output.store([i], input.load([i]) & b3)
E.ret([])
self.module.compile()
self.assertNotEqual(self.module.get_engine_address(), 0)
def testCustomOpEmission(self):
f = self.module.make_function("fooer", [self.i32Type, self.i32Type], [])
with E.ContextManager():
emitter = E.MLIRFunctionEmitter(f)
funcArg1, funcArg2 = emitter.bind_function_arguments()
boolAttr = self.module.boolAttr(True)
expr = self.module.op(
"foo", self.i32Type, [funcArg1, funcArg2], attr=boolAttr)
block = E.Block([E.Stmt(expr), E.Return()])
emitter.emit_inplace(block)
code = str(f)
self.assertIn('%0 = "foo"(%arg0, %arg1) {attr: true} : (i32, i32) -> i32',
code)
# Create 'addi' using the generic Op interface. We need an operation known
# to the execution engine so that the engine can compile it.
def testCustomOpCompilation(self):
f = self.module.make_function("adder", [self.i32Type], [])
with E.ContextManager():
emitter = E.MLIRFunctionEmitter(f)
funcArg, = emitter.bind_function_arguments()
c1 = self.module.op(
"std.constant",
self.i32Type, [],
with self.module.function_context("adder", [self.i32Type], []) as f:
c1 = E.op(
"std.constant", [], [self.i32Type],
value=self.module.integerAttr(self.i32Type, 42))
expr = self.module.op("std.addi", self.i32Type, [c1, funcArg])
block = E.Block([E.Stmt(expr), E.Return()])
emitter.emit_inplace(block)
self.module.compile()
self.assertNotEqual(self.module.get_engine_address(), 0)
E.op("std.addi", [c1, f.arg(0)], [self.i32Type])
E.ret([])
def testMLIREmission(self):
shape = [3, 4, 5]
m = self.module.make_memref_type(self.f32Type, shape)
f = self.module.make_function("copy", [m, m], [])
with E.ContextManager():
emitter = E.MLIRFunctionEmitter(f)
zero = emitter.bind_constant_index(0)
one = emitter.bind_constant_index(1)
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
M, N, O = emitter.bind_indexed_shape(input)
ivs = list(
map(E.Expr, [E.Bindable(self.indexType) for _ in range(len(shape))]))
lbs = [zero, zero, zero]
ubs = [M, N, O]
steps = [one, one, one]
# TODO(ntv): emitter.assertEqual(M, oM) etc
loop = E.Block([
E.For(ivs, lbs, ubs, steps, [output.store(ivs, input.load(ivs))]),
E.Return()
])
emitter.emit_inplace(loop)
# print(f) # uncomment to see the emitted IR
str = f.__str__()
self.assertIn("""store %0, %arg1[%i0, %i1, %i2] : memref<3x4x5xf32>""",
str)
self.module.compile()
self.assertNotEqual(self.module.get_engine_address(), 0)
self.module.compile()
self.assertNotEqual(self.module.get_engine_address(), 0)
if __name__ == "__main__":

View File

@ -1,64 +0,0 @@
# Copyright 2019 The MLIR Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python3 test for the MLIR EDSC C API and Python bindings"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import google_mlir.bindings.python.pybind as E
class EdscTest(unittest.TestCase):
def testSugaredMLIREmission(self):
shape = [3, 4, 5, 6, 7]
shape_t = [7, 4, 5, 6, 3]
module = E.MLIRModule()
t = module.make_scalar_type("f32")
m = module.make_memref_type(t, shape)
m_t = module.make_memref_type(t, shape_t)
f = module.make_function("copy", [m, m_t], [])
with E.ContextManager():
emitter = E.MLIRFunctionEmitter(f)
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
lbs, ubs, steps = emitter.bind_indexed_view(input)
i, *ivs, j = list(
map(E.Expr,
[E.Bindable(module.make_index_type()) for _ in range(len(shape))
]))
# n-D type and rank agnostic copy-transpose-first-last (where n >= 2).
loop = E.Block([
E.For([i, *ivs, j], lbs, ubs, steps,
[output.store([i, *ivs, j], input.load([j, *ivs, i]))]),
E.Return()
])
emitter.emit_inplace(loop)
# print(f) # uncomment to see the emitted IR
str = f.__str__()
self.assertIn("load %arg0[%i4, %i1, %i2, %i3, %i0]", str)
self.assertIn("store %0, %arg1[%i0, %i1, %i2, %i3, %i4]", str)
module.compile()
self.assertNotEqual(module.get_engine_address(), 0)
if __name__ == "__main__":
unittest.main()