forked from OSchip/llvm-project
Add mlir python APIs for creating operations, regions and blocks.
* The API is a bit more verbose than I feel like it needs to be. In a follow-up I'd like to abbreviate some things and look in to creating aliases for common accessors. * There is a lingering lifetime hazard between the module and newly added operations. We have the facilities now to solve for this but I will do that in a follow-up. * We may need to craft a more limited API for safely referencing successors when creating operations. We need more facilities to really prove that out and should defer for now. Differential Revision: https://reviews.llvm.org/D87996
This commit is contained in:
parent
4cf754c4bc
commit
c1ded6a759
|
@ -24,6 +24,22 @@ using llvm::SmallVector;
|
|||
// Docstrings (trivial, non-duplicated docstrings are included inline).
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
static const char kContextCreateOperationDocstring[] =
|
||||
R"(Creates a new operation.
|
||||
|
||||
Args:
|
||||
name: Operation name (e.g. "dialect.operation").
|
||||
location: A Location object.
|
||||
results: Sequence of Type representing op result types.
|
||||
attributes: Dict of str:Attribute.
|
||||
successors: List of Block for the operation's successors.
|
||||
regions: Number of regions to create.
|
||||
|
||||
Returns:
|
||||
A new "detached" Operation object. Detached operations can be added
|
||||
to blocks, which causes them to become "attached."
|
||||
)";
|
||||
|
||||
static const char kContextParseDocstring[] =
|
||||
R"(Parses a module's assembly format from a string.
|
||||
|
||||
|
@ -60,6 +76,13 @@ static const char kTypeStrDunderDocstring[] =
|
|||
static const char kDumpDocstring[] =
|
||||
R"(Dumps a debug representation of the object to stderr.)";
|
||||
|
||||
static const char kAppendBlockDocstring[] =
|
||||
R"(Appends a new block, with argument types as positional args.
|
||||
|
||||
Returns:
|
||||
The created block.
|
||||
)";
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Conversion utilities.
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -265,11 +288,25 @@ public:
|
|||
throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
|
||||
}
|
||||
|
||||
PyBlock appendBlock(py::args pyArgTypes) {
|
||||
operation->checkValid();
|
||||
llvm::SmallVector<MlirType, 4> argTypes;
|
||||
argTypes.reserve(pyArgTypes.size());
|
||||
for (auto &pyArg : pyArgTypes) {
|
||||
argTypes.push_back(pyArg.cast<PyType &>().type);
|
||||
}
|
||||
|
||||
MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
|
||||
mlirRegionAppendOwnedBlock(region, block);
|
||||
return PyBlock(operation, block);
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyBlockList>(m, "BlockList")
|
||||
.def("__getitem__", &PyBlockList::dunderGetItem)
|
||||
.def("__iter__", &PyBlockList::dunderIter)
|
||||
.def("__len__", &PyBlockList::dunderLen);
|
||||
.def("__len__", &PyBlockList::dunderLen)
|
||||
.def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -352,11 +389,41 @@ public:
|
|||
"attempt to access out of bounds operation");
|
||||
}
|
||||
|
||||
void insert(int index, PyOperation &newOperation) {
|
||||
parentOperation->checkValid();
|
||||
newOperation.checkValid();
|
||||
if (index < 0) {
|
||||
throw SetPyError(
|
||||
PyExc_IndexError,
|
||||
"only positive insertion indices are supported for operations");
|
||||
}
|
||||
if (newOperation.isAttached()) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
"attempt to insert an operation that has already been inserted");
|
||||
}
|
||||
// TODO: Needing to do this check is unfortunate, especially since it will
|
||||
// be a forward-scan, just like the following call to
|
||||
// mlirBlockInsertOwnedOperation. Switch to insert before/after once
|
||||
// D88148 lands.
|
||||
if (index > dunderLen()) {
|
||||
throw SetPyError(PyExc_IndexError,
|
||||
"attempt to insert operation past end");
|
||||
}
|
||||
mlirBlockInsertOwnedOperation(block, index, newOperation.get());
|
||||
newOperation.setAttached();
|
||||
// TODO: Rework the parentKeepAlive so as to avoid ownership hazards under
|
||||
// the new ownership.
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyOperationList>(m, "OperationList")
|
||||
.def("__getitem__", &PyOperationList::dunderGetItem)
|
||||
.def("__iter__", &PyOperationList::dunderIter)
|
||||
.def("__len__", &PyOperationList::dunderLen);
|
||||
.def("__len__", &PyOperationList::dunderLen)
|
||||
.def("insert", &PyOperationList::insert, py::arg("index"),
|
||||
py::arg("operation"),
|
||||
"Inserts an operation at an indexed position");
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -416,6 +483,87 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
|
|||
|
||||
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
|
||||
|
||||
py::object PyMlirContext::createOperation(
|
||||
std::string name, PyLocation location,
|
||||
llvm::Optional<std::vector<PyType *>> results,
|
||||
llvm::Optional<py::dict> attributes,
|
||||
llvm::Optional<std::vector<PyBlock *>> successors, int regions) {
|
||||
llvm::SmallVector<MlirType, 4> mlirResults;
|
||||
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
|
||||
llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
|
||||
|
||||
// General parameter validation.
|
||||
if (regions < 0)
|
||||
throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
|
||||
|
||||
// Unpack/validate results.
|
||||
if (results) {
|
||||
mlirResults.reserve(results->size());
|
||||
for (PyType *result : *results) {
|
||||
// TODO: Verify result type originate from the same context.
|
||||
if (!result)
|
||||
throw SetPyError(PyExc_ValueError, "result type cannot be None");
|
||||
mlirResults.push_back(result->type);
|
||||
}
|
||||
}
|
||||
// Unpack/validate attributes.
|
||||
if (attributes) {
|
||||
mlirAttributes.reserve(attributes->size());
|
||||
for (auto &it : *attributes) {
|
||||
|
||||
auto name = it.first.cast<std::string>();
|
||||
auto &attribute = it.second.cast<PyAttribute &>();
|
||||
// TODO: Verify attribute originates from the same context.
|
||||
mlirAttributes.emplace_back(std::move(name), attribute.attr);
|
||||
}
|
||||
}
|
||||
// Unpack/validate successors.
|
||||
if (successors) {
|
||||
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
|
||||
mlirSuccessors.reserve(successors->size());
|
||||
for (auto *successor : *successors) {
|
||||
// TODO: Verify successor originate from the same context.
|
||||
if (!successor)
|
||||
throw SetPyError(PyExc_ValueError, "successor block cannot be None");
|
||||
mlirSuccessors.push_back(successor->get());
|
||||
}
|
||||
}
|
||||
|
||||
// Apply unpacked/validated to the operation state. Beyond this
|
||||
// point, exceptions cannot be thrown or else the state will leak.
|
||||
MlirOperationState state = mlirOperationStateGet(name.c_str(), location.loc);
|
||||
if (!mlirResults.empty())
|
||||
mlirOperationStateAddResults(&state, mlirResults.size(),
|
||||
mlirResults.data());
|
||||
if (!mlirAttributes.empty()) {
|
||||
// Note that the attribute names directly reference bytes in
|
||||
// mlirAttributes, so that vector must not be changed from here
|
||||
// on.
|
||||
llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
|
||||
mlirNamedAttributes.reserve(mlirAttributes.size());
|
||||
for (auto &it : mlirAttributes)
|
||||
mlirNamedAttributes.push_back(
|
||||
mlirNamedAttributeGet(it.first.c_str(), it.second));
|
||||
mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
|
||||
mlirNamedAttributes.data());
|
||||
}
|
||||
if (!mlirSuccessors.empty())
|
||||
mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
|
||||
mlirSuccessors.data());
|
||||
if (regions) {
|
||||
llvm::SmallVector<MlirRegion, 4> mlirRegions;
|
||||
mlirRegions.resize(regions);
|
||||
for (int i = 0; i < regions; ++i)
|
||||
mlirRegions[i] = mlirRegionCreate();
|
||||
mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
|
||||
mlirRegions.data());
|
||||
}
|
||||
|
||||
// Construct the operation.
|
||||
MlirOperation operation = mlirOperationCreate(&state);
|
||||
return PyOperation::createDetached(getRef(), operation).releaseObject();
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyModule
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -1153,6 +1301,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
[](PyMlirContext &self, bool value) {
|
||||
mlirContextSetAllowUnregisteredDialects(self.get(), value);
|
||||
})
|
||||
.def("create_operation", &PyMlirContext::createOperation, py::arg("name"),
|
||||
py::arg("location"), py::arg("results") = py::none(),
|
||||
py::arg("attributes") = py::none(),
|
||||
py::arg("successors") = py::none(), py::arg("regions") = 0,
|
||||
kContextCreateOperationDocstring)
|
||||
.def(
|
||||
"parse_module",
|
||||
[](PyMlirContext &self, const std::string moduleAsm) {
|
||||
|
|
|
@ -17,9 +17,12 @@
|
|||
namespace mlir {
|
||||
namespace python {
|
||||
|
||||
class PyBlock;
|
||||
class PyLocation;
|
||||
class PyMlirContext;
|
||||
class PyModule;
|
||||
class PyOperation;
|
||||
class PyType;
|
||||
|
||||
/// Template for a reference to a concrete type which captures a python
|
||||
/// reference to its underlying python object.
|
||||
|
@ -112,6 +115,14 @@ public:
|
|||
/// Used for testing.
|
||||
size_t getLiveOperationCount();
|
||||
|
||||
/// Creates an operation. See corresponding python docstring.
|
||||
pybind11::object
|
||||
createOperation(std::string name, PyLocation location,
|
||||
llvm::Optional<std::vector<PyType *>> results,
|
||||
llvm::Optional<pybind11::dict> attributes,
|
||||
llvm::Optional<std::vector<PyBlock *>> successors,
|
||||
int regions);
|
||||
|
||||
private:
|
||||
PyMlirContext(MlirContext context);
|
||||
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
|
||||
|
@ -227,6 +238,10 @@ public:
|
|||
}
|
||||
|
||||
bool isAttached() { return attached; }
|
||||
void setAttached() {
|
||||
assert(!attached && "operation already attached");
|
||||
attached = true;
|
||||
}
|
||||
void checkValid();
|
||||
|
||||
private:
|
||||
|
|
|
@ -12,8 +12,16 @@
|
|||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
|
||||
namespace pybind11 {
|
||||
namespace detail {
|
||||
template <typename T>
|
||||
struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
|
||||
} // namespace detail
|
||||
} // namespace pybind11
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
|
||||
|
|
|
@ -99,3 +99,99 @@ def testTraverseOpRegionBlockIndices():
|
|||
walk_operations("", module.operation)
|
||||
|
||||
run(testTraverseOpRegionBlockIndices)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDetachedOperation
|
||||
def testDetachedOperation():
|
||||
ctx = mlir.ir.Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
loc = ctx.get_unknown_location()
|
||||
i32 = mlir.ir.IntegerType.get_signed(ctx, 32)
|
||||
op1 = ctx.create_operation(
|
||||
"custom.op1", loc, results=[i32, i32], regions=1, attributes={
|
||||
"foo": mlir.ir.StringAttr.get(ctx, "foo_value"),
|
||||
"bar": mlir.ir.StringAttr.get(ctx, "bar_value"),
|
||||
})
|
||||
# CHECK: %0:2 = "custom.op1"() ( {
|
||||
# CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
|
||||
print(op1)
|
||||
|
||||
# TODO: Check successors once enough infra exists to do it properly.
|
||||
|
||||
run(testDetachedOperation)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationInsert
|
||||
def testOperationInsert():
|
||||
ctx = mlir.ir.Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
module = ctx.parse_module(r"""
|
||||
func @f1(%arg0: i32) -> i32 {
|
||||
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
|
||||
return %1 : i32
|
||||
}
|
||||
""")
|
||||
|
||||
# Create test op.
|
||||
loc = ctx.get_unknown_location()
|
||||
op1 = ctx.create_operation("custom.op1", loc)
|
||||
op2 = ctx.create_operation("custom.op2", loc)
|
||||
|
||||
func = module.operation.regions[0].blocks[0].operations[0]
|
||||
entry_block = func.regions[0].blocks[0]
|
||||
entry_block.operations.insert(0, op1)
|
||||
entry_block.operations.insert(1, op2)
|
||||
# CHECK: func @f1
|
||||
# CHECK: "custom.op1"()
|
||||
# CHECK: "custom.op2"()
|
||||
# CHECK: %0 = "custom.addi"
|
||||
print(module)
|
||||
|
||||
# Trying to add a previously added op should raise.
|
||||
try:
|
||||
entry_block.operations.insert(0, op1)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
assert False, "expected insert of attached op to raise"
|
||||
|
||||
run(testOperationInsert)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationWithRegion
|
||||
def testOperationWithRegion():
|
||||
ctx = mlir.ir.Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
loc = ctx.get_unknown_location()
|
||||
i32 = mlir.ir.IntegerType.get_signed(ctx, 32)
|
||||
op1 = ctx.create_operation("custom.op1", loc, regions=1)
|
||||
block = op1.regions[0].blocks.append(i32, i32)
|
||||
# CHECK: "custom.op1"() ( {
|
||||
# CHECK: ^bb0(%arg0: si32, %arg1: si32): // no predecessors
|
||||
# CHECK: "custom.terminator"() : () -> ()
|
||||
# CHECK: }) : () -> ()
|
||||
terminator = ctx.create_operation("custom.terminator", loc)
|
||||
block.operations.insert(0, terminator)
|
||||
print(op1)
|
||||
|
||||
# Now add the whole operation to another op.
|
||||
# TODO: Verify lifetime hazard by nulling out the new owning module and
|
||||
# accessing op1.
|
||||
# TODO: Also verify accessing the terminator once both parents are nulled
|
||||
# out.
|
||||
module = ctx.parse_module(r"""
|
||||
func @f1(%arg0: i32) -> i32 {
|
||||
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
|
||||
return %1 : i32
|
||||
}
|
||||
""")
|
||||
func = module.operation.regions[0].blocks[0].operations[0]
|
||||
entry_block = func.regions[0].blocks[0]
|
||||
entry_block.operations.insert(0, op1)
|
||||
# CHECK: func @f1
|
||||
# CHECK: "custom.op1"()
|
||||
# CHECK: "custom.terminator"
|
||||
# CHECK: %0 = "custom.addi"
|
||||
print(module)
|
||||
|
||||
run(testOperationWithRegion)
|
||||
|
|
Loading…
Reference in New Issue