forked from OSchip/llvm-project
Python bindings: drop MLIREmitter and related functionality
This completes the transition of Python bindings to use the declarative builders infrastructure instead of the now-deprecated EDSC emitter infrastructure. The relevant unit tests have been replicated using the new functionality and the remaining end-to-end compilation tests have been updated accordingly. The latter show an improvement in brevity and readability. -- PiperOrigin-RevId: 241713489
This commit is contained in:
parent
509619829d
commit
7a30ac97c8
|
@ -27,8 +27,6 @@
|
|||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
#include "mlir/EDSC/MLIREmitter.h"
|
||||
#include "mlir/EDSC/Types.h"
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
|
@ -225,15 +223,6 @@ private:
|
|||
std::unique_ptr<mlir::ExecutionEngine> engine;
|
||||
};
|
||||
|
||||
struct ContextManager {
|
||||
void enter() { context = new ScopedEDSCContext(); }
|
||||
void exit(py::object, py::object, py::object) {
|
||||
delete context;
|
||||
context = nullptr;
|
||||
}
|
||||
mlir::edsc::ScopedEDSCContext *context;
|
||||
};
|
||||
|
||||
struct PythonFunctionContext {
|
||||
PythonFunctionContext(PythonFunction f) : function(f) {}
|
||||
PythonFunctionContext(PythonMLIRModule &module, const std::string &name,
|
||||
|
@ -268,18 +257,6 @@ PythonFunctionContext PythonMLIRModule::makeFunctionContext(
|
|||
return PythonFunctionContext(func);
|
||||
}
|
||||
|
||||
struct PythonExpr {
|
||||
PythonExpr() : expr{nullptr} {}
|
||||
PythonExpr(const PythonBindable &bindable);
|
||||
PythonExpr(const edsc_expr_t &expr) : expr{expr} {}
|
||||
operator edsc_expr_t() { return expr; }
|
||||
std::string str() {
|
||||
assert(expr && "unexpected empty expr");
|
||||
return Expr(*this).str();
|
||||
}
|
||||
edsc_expr_t expr;
|
||||
};
|
||||
|
||||
struct PythonBlockHandle {
|
||||
PythonBlockHandle() : value(nullptr) {}
|
||||
PythonBlockHandle(const PythonBlockHandle &other) = default;
|
||||
|
@ -440,45 +417,6 @@ public:
|
|||
BlockBuilder *builder = nullptr;
|
||||
};
|
||||
|
||||
struct PythonBindable : public PythonExpr {
|
||||
explicit PythonBindable(const PythonType &type)
|
||||
: PythonExpr(edsc_expr_t{makeBindable(type.type)}) {}
|
||||
PythonBindable(PythonExpr expr) : PythonExpr(expr) {
|
||||
assert(Expr(expr).isa<Bindable>() && "Expected Bindable");
|
||||
}
|
||||
std::string str() {
|
||||
assert(expr && "unexpected empty expr");
|
||||
return Expr(expr).str();
|
||||
}
|
||||
};
|
||||
|
||||
struct PythonStmt {
|
||||
PythonStmt() : stmt{nullptr} {}
|
||||
PythonStmt(const edsc_stmt_t &stmt) : stmt{stmt} {}
|
||||
PythonStmt(const PythonExpr &e) : stmt{makeStmt(e.expr)} {}
|
||||
operator edsc_stmt_t() { return stmt; }
|
||||
std::string str() {
|
||||
assert(stmt && "unexpected empty stmt");
|
||||
return Stmt(stmt).str();
|
||||
}
|
||||
edsc_stmt_t stmt;
|
||||
};
|
||||
|
||||
struct PythonBlock {
|
||||
PythonBlock() : blk{nullptr} {}
|
||||
PythonBlock(const edsc_block_t &other) : blk{other} {}
|
||||
PythonBlock(const PythonBlock &other) = default;
|
||||
operator edsc_block_t() { return blk; }
|
||||
std::string str() {
|
||||
assert(blk && "unexpected empty block");
|
||||
return StmtBlock(blk).str();
|
||||
}
|
||||
|
||||
PythonBlock set(const py::list &stmts);
|
||||
|
||||
edsc_block_t blk;
|
||||
};
|
||||
|
||||
struct PythonAttribute {
|
||||
PythonAttribute() : attr(nullptr) {}
|
||||
PythonAttribute(const mlir_attr_t &a) : attr(a) {}
|
||||
|
@ -550,12 +488,6 @@ private:
|
|||
std::unordered_map<std::string, PythonAttribute> attrs;
|
||||
};
|
||||
|
||||
struct PythonIndexed : public edsc_indexed_t {
|
||||
PythonIndexed(PythonExpr e) : edsc_indexed_t{makeIndexed(e)} {}
|
||||
PythonIndexed(PythonBindable b) : edsc_indexed_t{makeIndexed(b)} {}
|
||||
operator PythonExpr() { return PythonExpr(base); }
|
||||
};
|
||||
|
||||
struct PythonIndexedValue {
|
||||
explicit PythonIndexedValue(PythonType type)
|
||||
: indexed(Type::getFromOpaquePointer(type.type)) {}
|
||||
|
@ -584,57 +516,6 @@ struct PythonIndexedValue {
|
|||
IndexedValue indexed;
|
||||
};
|
||||
|
||||
struct PythonMaxExpr {
|
||||
PythonMaxExpr() : expr(nullptr) {}
|
||||
PythonMaxExpr(const edsc_max_expr_t &e) : expr(e) {}
|
||||
operator edsc_max_expr_t() { return expr; }
|
||||
|
||||
edsc_max_expr_t expr;
|
||||
};
|
||||
|
||||
struct PythonMinExpr {
|
||||
PythonMinExpr() : expr(nullptr) {}
|
||||
PythonMinExpr(const edsc_min_expr_t &e) : expr(e) {}
|
||||
operator edsc_min_expr_t() { return expr; }
|
||||
|
||||
edsc_min_expr_t expr;
|
||||
};
|
||||
|
||||
struct MLIRFunctionEmitter {
|
||||
MLIRFunctionEmitter(PythonFunction f)
|
||||
: currentFunction(reinterpret_cast<mlir::Function *>(f.function)),
|
||||
currentBuilder(currentFunction),
|
||||
emitter(¤tBuilder, currentFunction->getLoc()) {}
|
||||
|
||||
PythonExpr bindConstantBF16(double value);
|
||||
PythonExpr bindConstantF16(float value);
|
||||
PythonExpr bindConstantF32(float value);
|
||||
PythonExpr bindConstantF64(double value);
|
||||
PythonExpr bindConstantInt(int64_t value, unsigned bitwidth);
|
||||
PythonExpr bindConstantIndex(int64_t value);
|
||||
PythonExpr bindConstantFunction(PythonFunction func);
|
||||
PythonExpr bindFunctionArgument(unsigned pos);
|
||||
py::list bindFunctionArguments();
|
||||
py::list bindFunctionArgumentView(unsigned pos);
|
||||
py::list bindMemRefShape(PythonExpr boundMemRef);
|
||||
py::list bindIndexedMemRefShape(PythonIndexed boundMemRef) {
|
||||
return bindMemRefShape(boundMemRef.base);
|
||||
}
|
||||
py::list bindMemRefView(PythonExpr boundMemRef);
|
||||
py::list bindIndexedMemRefView(PythonIndexed boundMemRef) {
|
||||
return bindMemRefView(boundMemRef.base);
|
||||
}
|
||||
void emit(PythonStmt stmt);
|
||||
void emitBlock(PythonBlock block);
|
||||
void emitBlockBody(PythonBlock block);
|
||||
|
||||
private:
|
||||
mlir::Function *currentFunction;
|
||||
mlir::FuncBuilder currentBuilder;
|
||||
mlir::edsc::MLIREmitter emitter;
|
||||
edsc_mlir_emitter_t c_emitter;
|
||||
};
|
||||
|
||||
template <typename ListTy, typename PythonTy, typename Ty>
|
||||
ListTy makeCList(SmallVectorImpl<Ty> &owning, const py::list &list) {
|
||||
for (auto &inp : list) {
|
||||
|
@ -643,120 +524,11 @@ ListTy makeCList(SmallVectorImpl<Ty> &owning, const py::list &list) {
|
|||
return ListTy{owning.data(), owning.size()};
|
||||
}
|
||||
|
||||
static edsc_stmt_list_t makeCStmts(llvm::SmallVectorImpl<edsc_stmt_t> &owning,
|
||||
const py::list &stmts) {
|
||||
return makeCList<edsc_stmt_list_t, PythonStmt>(owning, stmts);
|
||||
}
|
||||
|
||||
static edsc_expr_list_t makeCExprs(llvm::SmallVectorImpl<edsc_expr_t> &owning,
|
||||
const py::list &exprs) {
|
||||
return makeCList<edsc_expr_list_t, PythonExpr>(owning, exprs);
|
||||
}
|
||||
|
||||
static mlir_type_list_t makeCTypes(llvm::SmallVectorImpl<mlir_type_t> &owning,
|
||||
const py::list &types) {
|
||||
return makeCList<mlir_type_list_t, PythonType>(owning, types);
|
||||
}
|
||||
|
||||
static edsc_block_list_t
|
||||
makeCBlocks(llvm::SmallVectorImpl<edsc_block_t> &owning,
|
||||
const py::list &blocks) {
|
||||
return makeCList<edsc_block_list_t, PythonBlock>(owning, blocks);
|
||||
}
|
||||
|
||||
PythonExpr::PythonExpr(const PythonBindable &bindable) : expr{bindable.expr} {}
|
||||
|
||||
PythonExpr MLIRFunctionEmitter::bindConstantBF16(double value) {
|
||||
return ::bindConstantBF16(edsc_mlir_emitter_t{&emitter}, value);
|
||||
}
|
||||
|
||||
PythonExpr MLIRFunctionEmitter::bindConstantF16(float value) {
|
||||
return ::bindConstantF16(edsc_mlir_emitter_t{&emitter}, value);
|
||||
}
|
||||
|
||||
PythonExpr MLIRFunctionEmitter::bindConstantF32(float value) {
|
||||
return ::bindConstantF32(edsc_mlir_emitter_t{&emitter}, value);
|
||||
}
|
||||
|
||||
PythonExpr MLIRFunctionEmitter::bindConstantF64(double value) {
|
||||
return ::bindConstantF64(edsc_mlir_emitter_t{&emitter}, value);
|
||||
}
|
||||
|
||||
PythonExpr MLIRFunctionEmitter::bindConstantInt(int64_t value,
|
||||
unsigned bitwidth) {
|
||||
return ::bindConstantInt(edsc_mlir_emitter_t{&emitter}, value, bitwidth);
|
||||
}
|
||||
|
||||
PythonExpr MLIRFunctionEmitter::bindConstantIndex(int64_t value) {
|
||||
return ::bindConstantIndex(edsc_mlir_emitter_t{&emitter}, value);
|
||||
}
|
||||
|
||||
PythonExpr MLIRFunctionEmitter::bindConstantFunction(PythonFunction func) {
|
||||
return ::bindConstantFunction(edsc_mlir_emitter_t{&emitter}, func);
|
||||
}
|
||||
|
||||
PythonExpr MLIRFunctionEmitter::bindFunctionArgument(unsigned pos) {
|
||||
return ::bindFunctionArgument(edsc_mlir_emitter_t{&emitter},
|
||||
mlir_func_t{currentFunction}, pos);
|
||||
}
|
||||
|
||||
PythonExpr getPythonType(edsc_expr_t e) { return PythonExpr(e); }
|
||||
|
||||
template <typename T> py::list makePyList(llvm::ArrayRef<T> owningResults) {
|
||||
py::list res;
|
||||
for (auto e : owningResults) {
|
||||
res.append(getPythonType(e));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
py::list MLIRFunctionEmitter::bindFunctionArguments() {
|
||||
auto arity = getFunctionArity(mlir_func_t{currentFunction});
|
||||
llvm::SmallVector<edsc_expr_t, 8> owningResults(arity);
|
||||
edsc_expr_list_t results{owningResults.data(), owningResults.size()};
|
||||
::bindFunctionArguments(edsc_mlir_emitter_t{&emitter},
|
||||
mlir_func_t{currentFunction}, &results);
|
||||
return makePyList(ArrayRef<edsc_expr_t>{owningResults});
|
||||
}
|
||||
|
||||
py::list MLIRFunctionEmitter::bindMemRefShape(PythonExpr boundMemRef) {
|
||||
auto rank = getBoundMemRefRank(edsc_mlir_emitter_t{&emitter}, boundMemRef);
|
||||
llvm::SmallVector<edsc_expr_t, 8> owningShapes(rank);
|
||||
edsc_expr_list_t resultShapes{owningShapes.data(), owningShapes.size()};
|
||||
::bindMemRefShape(edsc_mlir_emitter_t{&emitter}, boundMemRef, &resultShapes);
|
||||
return makePyList(ArrayRef<edsc_expr_t>{owningShapes});
|
||||
}
|
||||
|
||||
py::list MLIRFunctionEmitter::bindMemRefView(PythonExpr boundMemRef) {
|
||||
auto rank = getBoundMemRefRank(edsc_mlir_emitter_t{&emitter}, boundMemRef);
|
||||
// Own the PythonExpr for the arg as well as all its dims.
|
||||
llvm::SmallVector<edsc_expr_t, 8> owningLbs(rank);
|
||||
llvm::SmallVector<edsc_expr_t, 8> owningUbs(rank);
|
||||
llvm::SmallVector<edsc_expr_t, 8> owningSteps(rank);
|
||||
edsc_expr_list_t resultLbs{owningLbs.data(), owningLbs.size()};
|
||||
edsc_expr_list_t resultUbs{owningUbs.data(), owningUbs.size()};
|
||||
edsc_expr_list_t resultSteps{owningSteps.data(), owningSteps.size()};
|
||||
::bindMemRefView(edsc_mlir_emitter_t{&emitter}, boundMemRef, &resultLbs,
|
||||
&resultUbs, &resultSteps);
|
||||
py::list res;
|
||||
res.append(makePyList(ArrayRef<edsc_expr_t>{owningLbs}));
|
||||
res.append(makePyList(ArrayRef<edsc_expr_t>{owningUbs}));
|
||||
res.append(makePyList(ArrayRef<edsc_expr_t>{owningSteps}));
|
||||
return res;
|
||||
}
|
||||
|
||||
void MLIRFunctionEmitter::emit(PythonStmt stmt) {
|
||||
emitter.emitStmt(Stmt(stmt));
|
||||
}
|
||||
|
||||
void MLIRFunctionEmitter::emitBlock(PythonBlock block) {
|
||||
emitter.emitBlock(StmtBlock(block));
|
||||
}
|
||||
|
||||
void MLIRFunctionEmitter::emitBlockBody(PythonBlock block) {
|
||||
emitter.emitStmts(StmtBlock(block).getBody());
|
||||
}
|
||||
|
||||
PythonFunction
|
||||
PythonMLIRModule::declareFunction(const std::string &name,
|
||||
const py::list &inputs,
|
||||
|
@ -811,29 +583,6 @@ PythonMLIRModule::declareFunction(const std::string &name,
|
|||
return func;
|
||||
}
|
||||
|
||||
PythonExpr PythonMLIRModule::op(const std::string &name, PythonType type,
|
||||
const py::list &arguments,
|
||||
const py::list &successors,
|
||||
py::kwargs attributes) {
|
||||
SmallVector<edsc_expr_t, 8> owningExprs;
|
||||
SmallVector<edsc_block_t, 4> owningBlocks;
|
||||
SmallVector<mlir_named_attr_t, 4> owningAttrs;
|
||||
SmallVector<std::string, 4> owningAttrNames;
|
||||
|
||||
owningAttrs.reserve(attributes.size());
|
||||
owningAttrNames.reserve(attributes.size());
|
||||
for (const auto &kvp : attributes) {
|
||||
owningAttrNames.push_back(kvp.first.str());
|
||||
auto value = kvp.second.cast<PythonAttribute>();
|
||||
owningAttrs.push_back({owningAttrNames.back().c_str(), value});
|
||||
}
|
||||
|
||||
return PythonExpr(::Op(mlir_context_t(&mlirContext), name.c_str(), type,
|
||||
makeCExprs(owningExprs, arguments),
|
||||
makeCBlocks(owningBlocks, successors),
|
||||
{owningAttrs.data(), owningAttrs.size()}));
|
||||
}
|
||||
|
||||
PythonAttributedType PythonType::attachAttributeDict(
|
||||
const std::unordered_map<std::string, PythonAttribute> &attrs) const {
|
||||
return PythonAttributedType(*this, attrs);
|
||||
|
@ -847,88 +596,10 @@ PythonAttribute PythonMLIRModule::boolAttr(bool value) {
|
|||
return PythonAttribute(::makeBoolAttr(&mlirContext, value));
|
||||
}
|
||||
|
||||
PythonBlock PythonBlock::set(const py::list &stmts) {
|
||||
SmallVector<edsc_stmt_t, 8> owning;
|
||||
::BlockSetBody(blk, makeCStmts(owning, stmts));
|
||||
return *this;
|
||||
}
|
||||
|
||||
PythonExpr dispatchCall(py::args args, py::kwargs kwargs) {
|
||||
assert(args.size() != 0);
|
||||
llvm::SmallVector<edsc_expr_t, 8> exprs;
|
||||
exprs.reserve(args.size());
|
||||
for (auto arg : args) {
|
||||
exprs.push_back(arg.cast<PythonExpr>());
|
||||
}
|
||||
|
||||
edsc_expr_list_t operands{exprs.data() + 1, exprs.size() - 1};
|
||||
|
||||
if (kwargs && kwargs.contains("result")) {
|
||||
for (const auto &kvp : kwargs) {
|
||||
if (static_cast<std::string>(kvp.first.str()) == "result")
|
||||
return ::Call1(exprs.front(), kvp.second.cast<PythonType>(), operands);
|
||||
}
|
||||
}
|
||||
return ::Call0(exprs.front(), operands);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(pybind, m) {
|
||||
m.doc() =
|
||||
"Python bindings for MLIR Embedded Domain-Specific Components (EDSCs)";
|
||||
m.def("version", []() { return "EDSC Python extensions v0.0"; });
|
||||
m.def("initContext",
|
||||
[]() { return static_cast<void *>(new ScopedEDSCContext()); });
|
||||
m.def("deleteContext",
|
||||
[](void *ctx) { delete reinterpret_cast<ScopedEDSCContext *>(ctx); });
|
||||
|
||||
m.def("Block", [](const py::list &args, const py::list &stmts) {
|
||||
SmallVector<edsc_stmt_t, 8> owning;
|
||||
SmallVector<edsc_expr_t, 8> owningArgs;
|
||||
return PythonBlock(
|
||||
::Block(makeCExprs(owningArgs, args), makeCStmts(owning, stmts)));
|
||||
});
|
||||
m.def("Block", [](const py::list &stmts) {
|
||||
SmallVector<edsc_stmt_t, 8> owning;
|
||||
edsc_expr_list_t args{nullptr, 0};
|
||||
return PythonBlock(::Block(args, makeCStmts(owning, stmts)));
|
||||
});
|
||||
m.def(
|
||||
"Branch",
|
||||
[](PythonBlock destination, const py::list &operands) {
|
||||
SmallVector<edsc_expr_t, 8> owning;
|
||||
return PythonStmt(::Branch(destination, makeCExprs(owning, operands)));
|
||||
},
|
||||
py::arg("destination"), py::arg("operands") = py::list());
|
||||
m.def("CondBranch",
|
||||
[](PythonExpr condition, PythonBlock trueDestination,
|
||||
const py::list &trueOperands, PythonBlock falseDestination,
|
||||
const py::list &falseOperands) {
|
||||
SmallVector<edsc_expr_t, 8> owningTrue;
|
||||
SmallVector<edsc_expr_t, 8> owningFalse;
|
||||
return PythonStmt(::CondBranch(
|
||||
condition, trueDestination, makeCExprs(owningTrue, trueOperands),
|
||||
falseDestination, makeCExprs(owningFalse, falseOperands)));
|
||||
});
|
||||
m.def("CondBranch", [](PythonExpr condition, PythonBlock trueDestination,
|
||||
PythonBlock falseDestination) {
|
||||
edsc_expr_list_t emptyList;
|
||||
emptyList.exprs = nullptr;
|
||||
emptyList.n = 0;
|
||||
return PythonStmt(::CondBranch(condition, trueDestination, emptyList,
|
||||
falseDestination, emptyList));
|
||||
});
|
||||
m.def("For", [](const py::list &ivs, const py::list &lbs, const py::list &ubs,
|
||||
const py::list &steps, const py::list &stmts) {
|
||||
SmallVector<edsc_expr_t, 8> owningIVs;
|
||||
SmallVector<edsc_expr_t, 8> owningLBs;
|
||||
SmallVector<edsc_expr_t, 8> owningUBs;
|
||||
SmallVector<edsc_expr_t, 8> owningSteps;
|
||||
SmallVector<edsc_stmt_t, 8> owningStmts;
|
||||
return PythonStmt(
|
||||
::ForNest(makeCExprs(owningIVs, ivs), makeCExprs(owningLBs, lbs),
|
||||
makeCExprs(owningUBs, ubs), makeCExprs(owningSteps, steps),
|
||||
makeCStmts(owningStmts, stmts)));
|
||||
});
|
||||
m.def("version", []() { return "EDSC Python extensions v1.0"; });
|
||||
|
||||
py::class_<PythonLoopContext>(
|
||||
m, "LoopContext", "A context for building the body of a 'for' loop")
|
||||
|
@ -1032,68 +703,6 @@ PYBIND11_MODULE(pybind, m) {
|
|||
return ValueHandle::create(name, operandHandles, types, attrs);
|
||||
});
|
||||
|
||||
m.def("Max", [](const py::list &args) {
|
||||
SmallVector<edsc_expr_t, 8> owning;
|
||||
return PythonMaxExpr(::Max(makeCExprs(owning, args)));
|
||||
});
|
||||
m.def("Min", [](const py::list &args) {
|
||||
SmallVector<edsc_expr_t, 8> owning;
|
||||
return PythonMinExpr(::Min(makeCExprs(owning, args)));
|
||||
});
|
||||
m.def("For", [](PythonExpr iv, PythonExpr lb, PythonExpr ub, PythonExpr step,
|
||||
const py::list &stmts) {
|
||||
SmallVector<edsc_stmt_t, 8> owning;
|
||||
return PythonStmt(::For(iv, lb, ub, step, makeCStmts(owning, stmts)));
|
||||
});
|
||||
m.def("For", [](PythonExpr iv, PythonMaxExpr lb, PythonMinExpr ub,
|
||||
PythonExpr step, const py::list &stmts) {
|
||||
SmallVector<edsc_stmt_t, 8> owning;
|
||||
return PythonStmt(::MaxMinFor(iv, lb, ub, step, makeCStmts(owning, stmts)));
|
||||
});
|
||||
m.def("Select", [](PythonExpr cond, PythonExpr e1, PythonExpr e2) {
|
||||
return PythonExpr(::Select(cond, e1, e2));
|
||||
});
|
||||
m.def("Return", []() {
|
||||
return PythonStmt(::Return(edsc_expr_list_t{nullptr, 0}));
|
||||
});
|
||||
m.def("Return", [](const py::list &returns) {
|
||||
SmallVector<edsc_expr_t, 8> owningExprs;
|
||||
return PythonStmt(::Return(makeCExprs(owningExprs, returns)));
|
||||
});
|
||||
m.def("ConstantInteger", [](PythonType type, int64_t value) {
|
||||
return PythonExpr(::ConstantInteger(type, value));
|
||||
});
|
||||
|
||||
#define DEFINE_PYBIND_BINARY_OP(PYTHON_NAME, C_NAME) \
|
||||
m.def(PYTHON_NAME, [](PythonExpr e1, PythonExpr e2) { \
|
||||
return PythonExpr(::C_NAME(e1, e2)); \
|
||||
});
|
||||
|
||||
DEFINE_PYBIND_BINARY_OP("Add", Add);
|
||||
DEFINE_PYBIND_BINARY_OP("Mul", Mul);
|
||||
DEFINE_PYBIND_BINARY_OP("Sub", Sub);
|
||||
DEFINE_PYBIND_BINARY_OP("Div", Div);
|
||||
DEFINE_PYBIND_BINARY_OP("Rem", Rem);
|
||||
DEFINE_PYBIND_BINARY_OP("FloorDiv", FloorDiv);
|
||||
DEFINE_PYBIND_BINARY_OP("CeilDiv", CeilDiv);
|
||||
DEFINE_PYBIND_BINARY_OP("LT", LT);
|
||||
DEFINE_PYBIND_BINARY_OP("LE", LE);
|
||||
DEFINE_PYBIND_BINARY_OP("GT", GT);
|
||||
DEFINE_PYBIND_BINARY_OP("GE", GE);
|
||||
DEFINE_PYBIND_BINARY_OP("EQ", EQ);
|
||||
DEFINE_PYBIND_BINARY_OP("NE", NE);
|
||||
DEFINE_PYBIND_BINARY_OP("And", And);
|
||||
DEFINE_PYBIND_BINARY_OP("Or", Or);
|
||||
|
||||
#undef DEFINE_PYBIND_BINARY_OP
|
||||
|
||||
#define DEFINE_PYBIND_UNARY_OP(PYTHON_NAME, C_NAME) \
|
||||
m.def(PYTHON_NAME, [](PythonExpr e1) { return PythonExpr(::C_NAME(e1)); });
|
||||
|
||||
DEFINE_PYBIND_UNARY_OP("Negate", Negate);
|
||||
|
||||
#undef DEFINE_PYBIND_UNARY_OP
|
||||
|
||||
py::class_<PythonFunction>(m, "Function",
|
||||
"Wrapping class for mlir::Function.")
|
||||
.def(py::init<PythonFunction>())
|
||||
|
@ -1104,12 +713,6 @@ PYBIND11_MODULE(pybind, m) {
|
|||
.def("arg", &PythonFunction::arg,
|
||||
"Get the ValueHandle to the indexed argument of the function");
|
||||
|
||||
py::class_<PythonBlock>(m, "StmtBlock",
|
||||
"Wrapping class for mlir::edsc::StmtBlock")
|
||||
.def(py::init<PythonBlock>())
|
||||
.def("set", &PythonBlock::set)
|
||||
.def("__str__", &PythonBlock::str);
|
||||
|
||||
py::class_<PythonAttribute>(m, "Attribute",
|
||||
"Wrapping class for mlir::Attribute")
|
||||
.def(py::init<PythonAttribute>())
|
||||
|
@ -1143,9 +746,6 @@ PYBIND11_MODULE(pybind, m) {
|
|||
"directly require integration with a tensor library (e.g. numpy). This "
|
||||
"is left as the prerogative of libraries and frameworks for now.")
|
||||
.def(py::init<>())
|
||||
.def("op", &PythonMLIRModule::op, py::arg("name"), py::arg("type"),
|
||||
py::arg("arguments"), py::arg("successors") = py::list(),
|
||||
"Creates a new expression identified by its canonical name.")
|
||||
.def("boolAttr", &PythonMLIRModule::boolAttr,
|
||||
"Creates an mlir::BoolAttr with the given value")
|
||||
.def(
|
||||
|
@ -1198,15 +798,6 @@ PYBIND11_MODULE(pybind, m) {
|
|||
.def("__str__", &PythonMLIRModule::getIR,
|
||||
"Get the string representation of the module");
|
||||
|
||||
py::class_<ContextManager>(
|
||||
m, "ContextManager",
|
||||
"An EDSC context manager is the memory arena containing all the EDSC "
|
||||
"allocations.\nUsage:\n\n"
|
||||
"with E.ContextManager() as _:\n i = E.Expr(E.Bindable())\n ...")
|
||||
.def(py::init<>())
|
||||
.def("__enter__", &ContextManager::enter)
|
||||
.def("__exit__", &ContextManager::exit);
|
||||
|
||||
py::class_<PythonFunctionContext>(
|
||||
m, "FunctionContext", "A wrapper around mlir::edsc::ScopedContext")
|
||||
.def(py::init<PythonFunction>())
|
||||
|
@ -1313,131 +904,6 @@ PYBIND11_MODULE(pybind, m) {
|
|||
.def(py::init<PythonValueHandle>())
|
||||
.def("load", &PythonIndexedValue::load)
|
||||
.def("store", &PythonIndexedValue::store);
|
||||
|
||||
py::class_<MLIRFunctionEmitter>(
|
||||
m, "MLIRFunctionEmitter",
|
||||
"An MLIRFunctionEmitter is used to fill an empty function body. This is "
|
||||
"a staged process:\n"
|
||||
" 1. create or retrieve an mlir::Function `f` with an empty body;\n"
|
||||
" 2. make an `MLIRFunctionEmitter(f)` to build the current function;\n"
|
||||
" 3. create leaf Expr that are either Bindable or already Expr that are"
|
||||
" bound to constants and function arguments by using methods of "
|
||||
" `MLIRFunctionEmitter`;\n"
|
||||
" 4. build the function body using Expr, Indexed and Stmt;\n"
|
||||
" 5. emit the MLIR to implement the function body.")
|
||||
.def(py::init<PythonFunction>())
|
||||
.def("bind_constant_bf16", &MLIRFunctionEmitter::bindConstantBF16)
|
||||
.def("bind_constant_f16", &MLIRFunctionEmitter::bindConstantF16)
|
||||
.def("bind_constant_f32", &MLIRFunctionEmitter::bindConstantF32)
|
||||
.def("bind_constant_f64", &MLIRFunctionEmitter::bindConstantF64)
|
||||
.def("bind_constant_int", &MLIRFunctionEmitter::bindConstantInt)
|
||||
.def("bind_constant_index", &MLIRFunctionEmitter::bindConstantIndex)
|
||||
.def("bind_constant_function", &MLIRFunctionEmitter::bindConstantFunction)
|
||||
.def("bind_function_argument", &MLIRFunctionEmitter::bindFunctionArgument,
|
||||
"Returns an Expr that has been bound to a positional argument in "
|
||||
"the current Function.")
|
||||
.def("bind_function_arguments",
|
||||
&MLIRFunctionEmitter::bindFunctionArguments,
|
||||
"Returns a list of Expr where each Expr has been bound to the "
|
||||
"corresponding positional argument in the current Function.")
|
||||
.def("bind_memref_shape", &MLIRFunctionEmitter::bindMemRefShape,
|
||||
"Returns a list of Expr where each Expr has been bound to the "
|
||||
"corresponding dimension of the memref.")
|
||||
.def("bind_memref_view", &MLIRFunctionEmitter::bindMemRefView,
|
||||
"Returns three lists (lower bound, upper bound and step) of Expr "
|
||||
"where each triplet of Expr has been bound to the minimal offset, "
|
||||
"extent and stride of the corresponding dimension of the memref.")
|
||||
.def("bind_indexed_shape", &MLIRFunctionEmitter::bindIndexedMemRefShape,
|
||||
"Same as bind_memref_shape but returns a list of `Indexed` that "
|
||||
"support load and store operations")
|
||||
.def("bind_indexed_view", &MLIRFunctionEmitter::bindIndexedMemRefView,
|
||||
"Same as bind_memref_view but returns lists of `Indexed` that "
|
||||
"support load and store operations")
|
||||
.def("emit", &MLIRFunctionEmitter::emit,
|
||||
"Emits the MLIR for the EDSC expressions and statements in the "
|
||||
"current function body.")
|
||||
.def("emit", &MLIRFunctionEmitter::emitBlock,
|
||||
"Emits the MLIR for the EDSC statements into a new block")
|
||||
.def("emit_inplace", &MLIRFunctionEmitter::emitBlockBody,
|
||||
"Emits the MLIR for the EDSC statements contained in a EDSC block "
|
||||
"into the current function body without creating a new block");
|
||||
|
||||
py::class_<PythonExpr>(m, "Expr", "Wrapping class for mlir::edsc::Expr")
|
||||
.def(py::init<PythonBindable>())
|
||||
.def("__add__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::Add(e1, e2)); })
|
||||
.def("__sub__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::Sub(e1, e2)); })
|
||||
.def("__mul__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::Mul(e1, e2)); })
|
||||
.def("__div__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::Div(e1, e2)); })
|
||||
.def("__truediv__",
|
||||
[](PythonExpr e1, PythonExpr e2) {
|
||||
return PythonExpr(::Div(e1, e2));
|
||||
})
|
||||
.def("__mod__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::Rem(e1, e2)); })
|
||||
.def("__lt__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::LT(e1, e2)); })
|
||||
.def("__le__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::LE(e1, e2)); })
|
||||
.def("__gt__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::GT(e1, e2)); })
|
||||
.def("__ge__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::GE(e1, e2)); })
|
||||
.def("__eq__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::EQ(e1, e2)); })
|
||||
.def("__ne__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::NE(e1, e2)); })
|
||||
.def("__and__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::And(e1, e2)); })
|
||||
.def("__or__", [](PythonExpr e1,
|
||||
PythonExpr e2) { return PythonExpr(::Or(e1, e2)); })
|
||||
.def("__invert__", [](PythonExpr e) { return PythonExpr(::Negate(e)); })
|
||||
.def("__call__", &dispatchCall)
|
||||
.def("__str__", &PythonExpr::str,
|
||||
R"DOC(Returns the string value for the Expr)DOC");
|
||||
|
||||
py::class_<PythonBindable>(
|
||||
m, "Bindable",
|
||||
"Wrapping class for mlir::edsc::Bindable.\nA Bindable is a special Expr "
|
||||
"that can be bound manually to specific MLIR SSA Values.")
|
||||
.def(py::init<PythonType>())
|
||||
.def("__str__", &PythonBindable::str);
|
||||
|
||||
py::class_<PythonStmt>(m, "Stmt", "Wrapping class for mlir::edsc::Stmt.")
|
||||
.def(py::init<PythonExpr>())
|
||||
.def("__str__", &PythonStmt::str,
|
||||
R"DOC(Returns the string value for the Expr)DOC");
|
||||
|
||||
py::class_<PythonIndexed>(
|
||||
m, "Indexed",
|
||||
"Wrapping class for mlir::edsc::Indexed.\nAn Indexed is a wrapper class "
|
||||
"that support load and store operations.")
|
||||
.def(py::init<PythonExpr>(), R"DOC(Build from existing Expr)DOC")
|
||||
.def(py::init<PythonBindable>(), R"DOC(Build from existing Bindable)DOC")
|
||||
.def(
|
||||
"load",
|
||||
[](PythonIndexed &instance, const py::list &indices) {
|
||||
SmallVector<edsc_expr_t, 8> owning;
|
||||
return PythonExpr(Load(instance, makeCExprs(owning, indices)));
|
||||
},
|
||||
R"DOC(Returns an Expr that loads from an Indexed)DOC")
|
||||
.def(
|
||||
"store",
|
||||
[](PythonIndexed &instance, const py::list &indices,
|
||||
PythonExpr value) {
|
||||
SmallVector<edsc_expr_t, 8> owning;
|
||||
return PythonStmt(
|
||||
Store(value, instance, makeCExprs(owning, indices)));
|
||||
},
|
||||
R"DOC(Returns the Stmt that stores into an Indexed)DOC");
|
||||
|
||||
py::class_<PythonMaxExpr>(m, "MaxExpr",
|
||||
"Wrapping class for mlir::edsc::MaxExpr");
|
||||
py::class_<PythonMinExpr>(m, "MinExpr",
|
||||
"Wrapping class for mlir::edsc::MinExpr");
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Python2 and 3 test for the MLIR EDSC C API and Python bindings"""
|
||||
"""Python2 and 3 test for the MLIR EDSC Python bindings"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -400,212 +400,6 @@ class EdscTest(unittest.TestCase):
|
|||
self.assertIn("%2 = mulf %0, %1 : f32", code)
|
||||
self.assertIn("store %2, %arg2[%i0, %i1] : memref<32x32xf32>", code)
|
||||
|
||||
def testBindables(self):
|
||||
with E.ContextManager():
|
||||
i = E.Expr(E.Bindable(self.i32Type))
|
||||
self.assertIn("$1", i.__str__())
|
||||
|
||||
def testOneExpr(self):
|
||||
with E.ContextManager():
|
||||
i, lb, ub = list(
|
||||
map(E.Expr, [E.Bindable(self.i32Type) for _ in range(3)]))
|
||||
expr = E.Mul(i, E.Add(lb, ub))
|
||||
str = expr.__str__()
|
||||
self.assertIn("($1 * ($2 + $3))", str)
|
||||
|
||||
def testCustomOp(self):
|
||||
with E.ContextManager():
|
||||
a, b = (E.Expr(E.Bindable(self.i32Type)) for _ in range(2))
|
||||
c1 = self.module.op(
|
||||
"std.constant",
|
||||
self.i32Type, [],
|
||||
value=self.module.integerAttr(self.i32Type, 42))
|
||||
expr = self.module.op("std.addi", self.i32Type, [c1, b])
|
||||
str = expr.__str__()
|
||||
self.assertIn("addi(42, $2)", str)
|
||||
|
||||
def testOneLoop(self):
|
||||
with E.ContextManager():
|
||||
i, lb, ub, step = list(
|
||||
map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
loop = E.For(i, lb, ub, step, [E.Stmt(E.Add(lb, ub))])
|
||||
str = loop.__str__()
|
||||
self.assertIn("for($1 = $2 to $3 step $4) {", str)
|
||||
self.assertIn(" = ($2 + $3)", str)
|
||||
|
||||
def testTwoLoops(self):
|
||||
with E.ContextManager():
|
||||
i, lb, ub, step = list(
|
||||
map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
loop = E.For(i, lb, ub, step, [E.For(i, lb, ub, step, [E.Stmt(i)])])
|
||||
str = loop.__str__()
|
||||
self.assertIn("for($1 = $2 to $3 step $4) {", str)
|
||||
self.assertIn("for($1 = $2 to $3 step $4) {", str)
|
||||
self.assertIn("$5 = $1;", str)
|
||||
|
||||
def testNestedLoops(self):
|
||||
with E.ContextManager():
|
||||
i, lb, ub = list(
|
||||
map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
step = E.ConstantInteger(self.indexType, 42)
|
||||
ivs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
steps = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
loop = E.For(ivs, lbs, ubs, steps, [
|
||||
E.For(i, lb, ub, step, [E.Stmt(ub * step - lb)]),
|
||||
])
|
||||
str = loop.__str__()
|
||||
self.assertIn("for($5 = $9 to $13 step $17) {", str)
|
||||
self.assertIn("for($6 = $10 to $14 step $18) {", str)
|
||||
self.assertIn("for($7 = $11 to $15 step $19) {", str)
|
||||
self.assertIn("for($8 = $12 to $16 step $20) {", str)
|
||||
self.assertIn("for($1 = $2 to $3 step 42) {", str)
|
||||
self.assertIn("= (($3 * 42) + $2 * -1);", str)
|
||||
|
||||
def testMaxMinLoop(self):
|
||||
with E.ContextManager():
|
||||
i = E.Expr(E.Bindable(self.indexType))
|
||||
step = E.Expr(E.Bindable(self.indexType))
|
||||
lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
loop = E.For(i, E.Max(lbs), E.Min(ubs), step, [])
|
||||
s = str(loop)
|
||||
self.assertIn("for($1 = max($3, $4, $5, $6) to min($7, $8, $9) step $2)",
|
||||
s)
|
||||
|
||||
def testIndexed(self):
|
||||
with E.ContextManager():
|
||||
i, j, k = list(
|
||||
map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
memrefType = self.module.make_memref_type(self.f32Type, [42, 42])
|
||||
A, B, C = list(map(E.Indexed, [E.Bindable(memrefType) for _ in range(3)]))
|
||||
stmt = C.store([i, j], A.load([i, k]) * B.load([k, j]))
|
||||
str = stmt.__str__()
|
||||
self.assertIn(" = std.store(", str)
|
||||
|
||||
def testMatmul(self):
|
||||
with E.ContextManager():
|
||||
ivs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
steps = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
i, j, k = ivs[0], ivs[1], ivs[2]
|
||||
memrefType = self.module.make_memref_type(self.f32Type, [42, 42])
|
||||
A, B, C = list(map(E.Indexed, [E.Bindable(memrefType) for _ in range(3)]))
|
||||
loop = E.For(
|
||||
ivs, lbs, ubs, steps,
|
||||
[C.store([i, j],
|
||||
C.load([i, j]) + A.load([i, k]) * B.load([k, j]))])
|
||||
str = loop.__str__()
|
||||
self.assertIn("for($1 = $4 to $7 step $10) {", str)
|
||||
self.assertIn("for($2 = $5 to $8 step $11) {", str)
|
||||
self.assertIn("for($3 = $6 to $9 step $12) {", str)
|
||||
self.assertIn(" = std.store", str)
|
||||
|
||||
def testArithmetic(self):
|
||||
with E.ContextManager():
|
||||
i, j, k, l = list(
|
||||
map(E.Expr, [E.Bindable(self.f32Type) for _ in range(4)]))
|
||||
stmt = i % j + j * k - l / k
|
||||
str = stmt.__str__()
|
||||
self.assertIn("((($1 % $2) + ($2 * $3)) - ($4 / $3))", str)
|
||||
|
||||
def testBoolean(self):
|
||||
with E.ContextManager():
|
||||
i, j, k, l = list(
|
||||
map(E.Expr, [E.Bindable(self.i32Type) for _ in range(4)]))
|
||||
stmt1 = (i < j) & (j >= k)
|
||||
stmt2 = ~(stmt1 | (k == l))
|
||||
str = stmt2.__str__()
|
||||
# Note that "a | b" is currently implemented as ~(~a && ~b) and "~a" is
|
||||
# currently implemented as "constant 1 - a", which leads to this
|
||||
# expression.
|
||||
self.assertIn(
|
||||
"(1 - (1 - ((1 - (($1 < $2) && ($2 >= $3))) && (1 - ($3 == $4)))))",
|
||||
str)
|
||||
|
||||
def testSelect(self):
|
||||
with E.ContextManager():
|
||||
i, j, k = list(map(E.Expr, [E.Bindable(self.i32Type) for _ in range(3)]))
|
||||
stmt = E.Select(i > j, i, j)
|
||||
str = stmt.__str__()
|
||||
self.assertIn("select(($1 > $2), $1, $2)", str)
|
||||
|
||||
def testCall(self):
|
||||
with E.ContextManager():
|
||||
module = E.MLIRModule()
|
||||
f32 = module.make_scalar_type("f32")
|
||||
func, arg = [E.Expr(E.Bindable(f32)) for _ in range(2)]
|
||||
code = func(arg, result=f32)
|
||||
self.assertIn("@$1($2)", str(code))
|
||||
|
||||
def testBlock(self):
|
||||
with E.ContextManager():
|
||||
i, j = list(map(E.Expr, [E.Bindable(self.f32Type) for _ in range(2)]))
|
||||
stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)])
|
||||
str = stmt.__str__()
|
||||
self.assertIn("^bb", str)
|
||||
self.assertIn(" = ($1 + $2)", str)
|
||||
self.assertIn(" = ($1 - $2)", str)
|
||||
|
||||
def testBlockArgs(self):
|
||||
with E.ContextManager():
|
||||
module = E.MLIRModule()
|
||||
t = module.make_scalar_type("i", 32)
|
||||
i, j = list(map(E.Expr, [E.Bindable(t) for _ in range(2)]))
|
||||
stmt = E.Block([i, j], [E.Stmt(i + j)])
|
||||
str = stmt.__str__()
|
||||
self.assertIn("^bb", str)
|
||||
self.assertIn("($1, $2):", str)
|
||||
self.assertIn("($1 + $2)", str)
|
||||
|
||||
def testBranch(self):
|
||||
with E.ContextManager():
|
||||
i, j = list(map(E.Expr, [E.Bindable(self.i32Type) for _ in range(2)]))
|
||||
b1 = E.Block([E.Stmt(i + j)])
|
||||
b2 = E.Block([E.Branch(b1)])
|
||||
str1 = b1.__str__()
|
||||
str2 = b2.__str__()
|
||||
self.assertIn("^bb1:\n" + "$4 = ($1 + $2)", str1)
|
||||
self.assertIn("^bb2:\n" + "$6 = br ^bb1", str2)
|
||||
|
||||
def testBranchArgs(self):
|
||||
with E.ContextManager():
|
||||
b1arg, b2arg = (E.Expr(E.Bindable(self.i32Type)) for _ in range(2))
|
||||
# Declare empty blocks with arguments and bind those arguments.
|
||||
b1 = E.Block([b1arg], [])
|
||||
b2 = E.Block([b2arg], [])
|
||||
one = E.ConstantInteger(self.i32Type, 1)
|
||||
# Make blocks branch to each other in a sort of infinite loop.
|
||||
# This checks that the EDSC implementation does not fall into such loop.
|
||||
b1.set([E.Branch(b2, [b1arg + one])])
|
||||
b2.set([E.Branch(b1, [b2arg])])
|
||||
str1 = b1.__str__()
|
||||
str2 = b2.__str__()
|
||||
self.assertIn("^bb1($1):\n" + "$6 = br ^bb2(($1 + 1))", str1)
|
||||
self.assertIn("^bb2($2):\n" + "$8 = br ^bb1($2)", str2)
|
||||
|
||||
def testCondBranch(self):
|
||||
with E.ContextManager():
|
||||
cond = E.Expr(E.Bindable(self.boolType))
|
||||
b1 = E.Block([])
|
||||
b2 = E.Block([])
|
||||
b3 = E.Block([E.CondBranch(cond, b1, b2)])
|
||||
str = b3.__str__()
|
||||
self.assertIn("cond_br($1, ^bb1, ^bb2)", str)
|
||||
|
||||
def testCondBranchArgs(self):
|
||||
with E.ContextManager():
|
||||
arg1, arg2, arg3 = (E.Expr(E.Bindable(self.i32Type)) for _ in range(3))
|
||||
expr1, expr2, expr3 = (E.Expr(E.Bindable(self.i32Type)) for _ in range(3))
|
||||
cond = E.Expr(E.Bindable(self.boolType))
|
||||
b1 = E.Block([arg1], [])
|
||||
b2 = E.Block([arg2, arg3], [])
|
||||
b3 = E.Block([E.CondBranch(cond, b1, [expr1], b2, [expr2, expr3])])
|
||||
str = b3.__str__()
|
||||
self.assertIn("cond_br($7, ^bb1($4), ^bb2($5, $6))", str)
|
||||
|
||||
def testMLIRScalarTypes(self):
|
||||
module = E.MLIRModule()
|
||||
t = module.make_scalar_type("bf16")
|
||||
|
@ -641,50 +435,6 @@ class EdscTest(unittest.TestCase):
|
|||
f = module.make_function("sqrtf", [t], [t])
|
||||
self.assertIn("func @sqrtf(%arg0: f32) -> f32", f.__str__())
|
||||
|
||||
def testMLIRConstantEmission(self):
|
||||
module = E.MLIRModule()
|
||||
f = module.make_function("constants", [], [])
|
||||
with E.ContextManager():
|
||||
emitter = E.MLIRFunctionEmitter(f)
|
||||
emitter.bind_constant_bf16(1.23)
|
||||
emitter.bind_constant_f16(1.23)
|
||||
emitter.bind_constant_f32(1.23)
|
||||
emitter.bind_constant_f64(1.23)
|
||||
emitter.bind_constant_int(1, 1)
|
||||
emitter.bind_constant_int(123, 8)
|
||||
emitter.bind_constant_int(123, 16)
|
||||
emitter.bind_constant_int(123, 32)
|
||||
emitter.bind_constant_index(123)
|
||||
emitter.bind_constant_function(f)
|
||||
str = f.__str__()
|
||||
self.assertIn("constant 1.230000e+00 : bf16", str)
|
||||
self.assertIn("constant 1.230470e+00 : f16", str)
|
||||
self.assertIn("constant 1.230000e+00 : f32", str)
|
||||
self.assertIn("constant 1.230000e+00 : f64", str)
|
||||
self.assertIn("constant 1 : i1", str)
|
||||
self.assertIn("constant 123 : i8", str)
|
||||
self.assertIn("constant 123 : i16", str)
|
||||
self.assertIn("constant 123 : i32", str)
|
||||
self.assertIn("constant 123 : index", str)
|
||||
self.assertIn("constant @constants : () -> ()", str)
|
||||
|
||||
def testMLIRBuiltinEmission(self):
|
||||
module = E.MLIRModule()
|
||||
m = module.make_memref_type(self.f32Type, [10]) # f32 tensor
|
||||
f = module.make_function("call_builtin", [m, m], [])
|
||||
with E.ContextManager():
|
||||
emitter = E.MLIRFunctionEmitter(f)
|
||||
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
|
||||
fn = module.declare_function("sqrtf", [self.f32Type], [self.f32Type])
|
||||
fn = emitter.bind_constant_function(fn)
|
||||
zero = emitter.bind_constant_index(0)
|
||||
emitter.emit_inplace(E.Block([
|
||||
output.store([zero], fn(input.load([zero]), result=self.f32Type))
|
||||
]))
|
||||
str = f.__str__()
|
||||
self.assertIn("%f = constant @sqrtf : (f32) -> f32", str)
|
||||
self.assertIn("call_indirect %f(%0) : (f32) -> f32", str)
|
||||
|
||||
def testFunctionDeclaration(self):
|
||||
module = E.MLIRModule()
|
||||
boolAttr = self.module.boolAttr(True)
|
||||
|
@ -697,99 +447,35 @@ class EdscTest(unittest.TestCase):
|
|||
"func @foo(memref<10xf32>, memref<10xf32> {llvm.noalias: true}, memref<10xf32> {readonly: true})",
|
||||
str)
|
||||
|
||||
def testMLIRBooleanEmission(self):
|
||||
def testMLIRBooleanCompilation(self):
|
||||
m = self.module.make_memref_type(self.boolType, [10]) # i1 tensor
|
||||
f = self.module.make_function("mkbooltensor", [m, m], [])
|
||||
with E.ContextManager():
|
||||
emitter = E.MLIRFunctionEmitter(f)
|
||||
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
|
||||
i = E.Expr(E.Bindable(self.indexType))
|
||||
j = E.Expr(E.Bindable(self.indexType))
|
||||
k = E.Expr(E.Bindable(self.indexType))
|
||||
idxs = [i, j, k]
|
||||
zero = emitter.bind_constant_index(0)
|
||||
one = emitter.bind_constant_index(1)
|
||||
ten = emitter.bind_constant_index(10)
|
||||
b1 = E.And(i < j, j < k)
|
||||
b2 = E.Negate(b1)
|
||||
b3 = E.Or(b2, k < j)
|
||||
loop = E.Block([
|
||||
E.For(idxs, [zero]*3, [ten]*3, [one]*3, [
|
||||
output.store([i], E.And(input.load([i]), b3))
|
||||
]),
|
||||
E.Return()
|
||||
])
|
||||
emitter.emit_inplace(loop)
|
||||
# str = f.__str__()
|
||||
# print(str)
|
||||
with self.module.function_context("mkbooltensor", [m, m], []) as f:
|
||||
input = E.IndexedValue(f.arg(0))
|
||||
output = E.IndexedValue(f.arg(1))
|
||||
zero = E.constant_index(0)
|
||||
ten = E.constant_index(10)
|
||||
with E.LoopNestContext([zero] * 3, [ten] * 3, [1] * 3) as (i, j, k):
|
||||
b1 = (i < j) & (j < k)
|
||||
b2 = ~b1
|
||||
b3 = b2 | (k < j)
|
||||
output.store([i], input.load([i]) & b3)
|
||||
E.ret([])
|
||||
|
||||
self.module.compile()
|
||||
self.assertNotEqual(self.module.get_engine_address(), 0)
|
||||
|
||||
def testCustomOpEmission(self):
|
||||
f = self.module.make_function("fooer", [self.i32Type, self.i32Type], [])
|
||||
with E.ContextManager():
|
||||
emitter = E.MLIRFunctionEmitter(f)
|
||||
funcArg1, funcArg2 = emitter.bind_function_arguments()
|
||||
boolAttr = self.module.boolAttr(True)
|
||||
expr = self.module.op(
|
||||
"foo", self.i32Type, [funcArg1, funcArg2], attr=boolAttr)
|
||||
block = E.Block([E.Stmt(expr), E.Return()])
|
||||
emitter.emit_inplace(block)
|
||||
|
||||
code = str(f)
|
||||
self.assertIn('%0 = "foo"(%arg0, %arg1) {attr: true} : (i32, i32) -> i32',
|
||||
code)
|
||||
|
||||
# Create 'addi' using the generic Op interface. We need an operation known
|
||||
# to the execution engine so that the engine can compile it.
|
||||
def testCustomOpCompilation(self):
|
||||
f = self.module.make_function("adder", [self.i32Type], [])
|
||||
with E.ContextManager():
|
||||
emitter = E.MLIRFunctionEmitter(f)
|
||||
funcArg, = emitter.bind_function_arguments()
|
||||
c1 = self.module.op(
|
||||
"std.constant",
|
||||
self.i32Type, [],
|
||||
with self.module.function_context("adder", [self.i32Type], []) as f:
|
||||
c1 = E.op(
|
||||
"std.constant", [], [self.i32Type],
|
||||
value=self.module.integerAttr(self.i32Type, 42))
|
||||
expr = self.module.op("std.addi", self.i32Type, [c1, funcArg])
|
||||
block = E.Block([E.Stmt(expr), E.Return()])
|
||||
emitter.emit_inplace(block)
|
||||
self.module.compile()
|
||||
self.assertNotEqual(self.module.get_engine_address(), 0)
|
||||
E.op("std.addi", [c1, f.arg(0)], [self.i32Type])
|
||||
E.ret([])
|
||||
|
||||
|
||||
def testMLIREmission(self):
|
||||
shape = [3, 4, 5]
|
||||
m = self.module.make_memref_type(self.f32Type, shape)
|
||||
f = self.module.make_function("copy", [m, m], [])
|
||||
|
||||
with E.ContextManager():
|
||||
emitter = E.MLIRFunctionEmitter(f)
|
||||
zero = emitter.bind_constant_index(0)
|
||||
one = emitter.bind_constant_index(1)
|
||||
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
|
||||
M, N, O = emitter.bind_indexed_shape(input)
|
||||
|
||||
ivs = list(
|
||||
map(E.Expr, [E.Bindable(self.indexType) for _ in range(len(shape))]))
|
||||
lbs = [zero, zero, zero]
|
||||
ubs = [M, N, O]
|
||||
steps = [one, one, one]
|
||||
|
||||
# TODO(ntv): emitter.assertEqual(M, oM) etc
|
||||
loop = E.Block([
|
||||
E.For(ivs, lbs, ubs, steps, [output.store(ivs, input.load(ivs))]),
|
||||
E.Return()
|
||||
])
|
||||
emitter.emit_inplace(loop)
|
||||
|
||||
# print(f) # uncomment to see the emitted IR
|
||||
str = f.__str__()
|
||||
self.assertIn("""store %0, %arg1[%i0, %i1, %i2] : memref<3x4x5xf32>""",
|
||||
str)
|
||||
|
||||
self.module.compile()
|
||||
self.assertNotEqual(self.module.get_engine_address(), 0)
|
||||
self.module.compile()
|
||||
self.assertNotEqual(self.module.get_engine_address(), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,64 +0,0 @@
|
|||
# Copyright 2019 The MLIR Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Python3 test for the MLIR EDSC C API and Python bindings"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import google_mlir.bindings.python.pybind as E
|
||||
|
||||
|
||||
class EdscTest(unittest.TestCase):
|
||||
|
||||
def testSugaredMLIREmission(self):
|
||||
shape = [3, 4, 5, 6, 7]
|
||||
shape_t = [7, 4, 5, 6, 3]
|
||||
module = E.MLIRModule()
|
||||
t = module.make_scalar_type("f32")
|
||||
m = module.make_memref_type(t, shape)
|
||||
m_t = module.make_memref_type(t, shape_t)
|
||||
f = module.make_function("copy", [m, m_t], [])
|
||||
|
||||
with E.ContextManager():
|
||||
emitter = E.MLIRFunctionEmitter(f)
|
||||
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
|
||||
lbs, ubs, steps = emitter.bind_indexed_view(input)
|
||||
i, *ivs, j = list(
|
||||
map(E.Expr,
|
||||
[E.Bindable(module.make_index_type()) for _ in range(len(shape))
|
||||
]))
|
||||
|
||||
# n-D type and rank agnostic copy-transpose-first-last (where n >= 2).
|
||||
loop = E.Block([
|
||||
E.For([i, *ivs, j], lbs, ubs, steps,
|
||||
[output.store([i, *ivs, j], input.load([j, *ivs, i]))]),
|
||||
E.Return()
|
||||
])
|
||||
emitter.emit_inplace(loop)
|
||||
|
||||
# print(f) # uncomment to see the emitted IR
|
||||
str = f.__str__()
|
||||
self.assertIn("load %arg0[%i4, %i1, %i2, %i3, %i0]", str)
|
||||
self.assertIn("store %0, %arg1[%i0, %i1, %i2, %i3, %i4]", str)
|
||||
|
||||
module.compile()
|
||||
self.assertNotEqual(module.get_engine_address(), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue