Add mlir python APIs for creating operations, regions and blocks.

* The API is a bit more verbose than I feel like it needs to be. In a follow-up I'd like to abbreviate some things and look in to creating aliases for common accessors.
* There is a lingering lifetime hazard between the module and newly added operations. We have the facilities now to solve for this but I will do that in a follow-up.
* We may need to craft a more limited API for safely referencing successors when creating operations. We need more facilities to really prove that out and should defer for now.

Differential Revision: https://reviews.llvm.org/D87996
This commit is contained in:
Stella Laurenzo 2020-09-20 21:25:46 -07:00
parent 4cf754c4bc
commit c1ded6a759
4 changed files with 274 additions and 2 deletions

View File

@ -24,6 +24,22 @@ using llvm::SmallVector;
// Docstrings (trivial, non-duplicated docstrings are included inline).
//------------------------------------------------------------------------------
static const char kContextCreateOperationDocstring[] =
R"(Creates a new operation.
Args:
name: Operation name (e.g. "dialect.operation").
location: A Location object.
results: Sequence of Type representing op result types.
attributes: Dict of str:Attribute.
successors: List of Block for the operation's successors.
regions: Number of regions to create.
Returns:
A new "detached" Operation object. Detached operations can be added
to blocks, which causes them to become "attached."
)";
static const char kContextParseDocstring[] =
R"(Parses a module's assembly format from a string.
@ -60,6 +76,13 @@ static const char kTypeStrDunderDocstring[] =
static const char kDumpDocstring[] =
R"(Dumps a debug representation of the object to stderr.)";
static const char kAppendBlockDocstring[] =
R"(Appends a new block, with argument types as positional args.
Returns:
The created block.
)";
//------------------------------------------------------------------------------
// Conversion utilities.
//------------------------------------------------------------------------------
@ -265,11 +288,25 @@ public:
throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
}
PyBlock appendBlock(py::args pyArgTypes) {
operation->checkValid();
llvm::SmallVector<MlirType, 4> argTypes;
argTypes.reserve(pyArgTypes.size());
for (auto &pyArg : pyArgTypes) {
argTypes.push_back(pyArg.cast<PyType &>().type);
}
MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
mlirRegionAppendOwnedBlock(region, block);
return PyBlock(operation, block);
}
static void bind(py::module &m) {
py::class_<PyBlockList>(m, "BlockList")
.def("__getitem__", &PyBlockList::dunderGetItem)
.def("__iter__", &PyBlockList::dunderIter)
.def("__len__", &PyBlockList::dunderLen);
.def("__len__", &PyBlockList::dunderLen)
.def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
}
private:
@ -352,11 +389,41 @@ public:
"attempt to access out of bounds operation");
}
void insert(int index, PyOperation &newOperation) {
parentOperation->checkValid();
newOperation.checkValid();
if (index < 0) {
throw SetPyError(
PyExc_IndexError,
"only positive insertion indices are supported for operations");
}
if (newOperation.isAttached()) {
throw SetPyError(
PyExc_ValueError,
"attempt to insert an operation that has already been inserted");
}
// TODO: Needing to do this check is unfortunate, especially since it will
// be a forward-scan, just like the following call to
// mlirBlockInsertOwnedOperation. Switch to insert before/after once
// D88148 lands.
if (index > dunderLen()) {
throw SetPyError(PyExc_IndexError,
"attempt to insert operation past end");
}
mlirBlockInsertOwnedOperation(block, index, newOperation.get());
newOperation.setAttached();
// TODO: Rework the parentKeepAlive so as to avoid ownership hazards under
// the new ownership.
}
static void bind(py::module &m) {
py::class_<PyOperationList>(m, "OperationList")
.def("__getitem__", &PyOperationList::dunderGetItem)
.def("__iter__", &PyOperationList::dunderIter)
.def("__len__", &PyOperationList::dunderLen);
.def("__len__", &PyOperationList::dunderLen)
.def("insert", &PyOperationList::insert, py::arg("index"),
py::arg("operation"),
"Inserts an operation at an indexed position");
}
private:
@ -416,6 +483,87 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
py::object PyMlirContext::createOperation(
std::string name, PyLocation location,
llvm::Optional<std::vector<PyType *>> results,
llvm::Optional<py::dict> attributes,
llvm::Optional<std::vector<PyBlock *>> successors, int regions) {
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
// General parameter validation.
if (regions < 0)
throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
// Unpack/validate results.
if (results) {
mlirResults.reserve(results->size());
for (PyType *result : *results) {
// TODO: Verify result type originate from the same context.
if (!result)
throw SetPyError(PyExc_ValueError, "result type cannot be None");
mlirResults.push_back(result->type);
}
}
// Unpack/validate attributes.
if (attributes) {
mlirAttributes.reserve(attributes->size());
for (auto &it : *attributes) {
auto name = it.first.cast<std::string>();
auto &attribute = it.second.cast<PyAttribute &>();
// TODO: Verify attribute originates from the same context.
mlirAttributes.emplace_back(std::move(name), attribute.attr);
}
}
// Unpack/validate successors.
if (successors) {
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
mlirSuccessors.reserve(successors->size());
for (auto *successor : *successors) {
// TODO: Verify successor originate from the same context.
if (!successor)
throw SetPyError(PyExc_ValueError, "successor block cannot be None");
mlirSuccessors.push_back(successor->get());
}
}
// Apply unpacked/validated to the operation state. Beyond this
// point, exceptions cannot be thrown or else the state will leak.
MlirOperationState state = mlirOperationStateGet(name.c_str(), location.loc);
if (!mlirResults.empty())
mlirOperationStateAddResults(&state, mlirResults.size(),
mlirResults.data());
if (!mlirAttributes.empty()) {
// Note that the attribute names directly reference bytes in
// mlirAttributes, so that vector must not be changed from here
// on.
llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
mlirNamedAttributes.reserve(mlirAttributes.size());
for (auto &it : mlirAttributes)
mlirNamedAttributes.push_back(
mlirNamedAttributeGet(it.first.c_str(), it.second));
mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
mlirNamedAttributes.data());
}
if (!mlirSuccessors.empty())
mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
mlirSuccessors.data());
if (regions) {
llvm::SmallVector<MlirRegion, 4> mlirRegions;
mlirRegions.resize(regions);
for (int i = 0; i < regions; ++i)
mlirRegions[i] = mlirRegionCreate();
mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
mlirRegions.data());
}
// Construct the operation.
MlirOperation operation = mlirOperationCreate(&state);
return PyOperation::createDetached(getRef(), operation).releaseObject();
}
//------------------------------------------------------------------------------
// PyModule
//------------------------------------------------------------------------------
@ -1153,6 +1301,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
[](PyMlirContext &self, bool value) {
mlirContextSetAllowUnregisteredDialects(self.get(), value);
})
.def("create_operation", &PyMlirContext::createOperation, py::arg("name"),
py::arg("location"), py::arg("results") = py::none(),
py::arg("attributes") = py::none(),
py::arg("successors") = py::none(), py::arg("regions") = 0,
kContextCreateOperationDocstring)
.def(
"parse_module",
[](PyMlirContext &self, const std::string moduleAsm) {

View File

@ -17,9 +17,12 @@
namespace mlir {
namespace python {
class PyBlock;
class PyLocation;
class PyMlirContext;
class PyModule;
class PyOperation;
class PyType;
/// Template for a reference to a concrete type which captures a python
/// reference to its underlying python object.
@ -112,6 +115,14 @@ public:
/// Used for testing.
size_t getLiveOperationCount();
/// Creates an operation. See corresponding python docstring.
pybind11::object
createOperation(std::string name, PyLocation location,
llvm::Optional<std::vector<PyType *>> results,
llvm::Optional<pybind11::dict> attributes,
llvm::Optional<std::vector<PyBlock *>> successors,
int regions);
private:
PyMlirContext(MlirContext context);
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
@ -227,6 +238,10 @@ public:
}
bool isAttached() { return attached; }
void setAttached() {
assert(!attached && "operation already attached");
attached = true;
}
void checkValid();
private:

View File

@ -12,8 +12,16 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Twine.h"
namespace pybind11 {
namespace detail {
template <typename T>
struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
} // namespace detail
} // namespace pybind11
namespace mlir {
namespace python {

View File

@ -99,3 +99,99 @@ def testTraverseOpRegionBlockIndices():
walk_operations("", module.operation)
run(testTraverseOpRegionBlockIndices)
# CHECK-LABEL: TEST: testDetachedOperation
def testDetachedOperation():
ctx = mlir.ir.Context()
ctx.allow_unregistered_dialects = True
loc = ctx.get_unknown_location()
i32 = mlir.ir.IntegerType.get_signed(ctx, 32)
op1 = ctx.create_operation(
"custom.op1", loc, results=[i32, i32], regions=1, attributes={
"foo": mlir.ir.StringAttr.get(ctx, "foo_value"),
"bar": mlir.ir.StringAttr.get(ctx, "bar_value"),
})
# CHECK: %0:2 = "custom.op1"() ( {
# CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
print(op1)
# TODO: Check successors once enough infra exists to do it properly.
run(testDetachedOperation)
# CHECK-LABEL: TEST: testOperationInsert
def testOperationInsert():
ctx = mlir.ir.Context()
ctx.allow_unregistered_dialects = True
module = ctx.parse_module(r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""")
# Create test op.
loc = ctx.get_unknown_location()
op1 = ctx.create_operation("custom.op1", loc)
op2 = ctx.create_operation("custom.op2", loc)
func = module.operation.regions[0].blocks[0].operations[0]
entry_block = func.regions[0].blocks[0]
entry_block.operations.insert(0, op1)
entry_block.operations.insert(1, op2)
# CHECK: func @f1
# CHECK: "custom.op1"()
# CHECK: "custom.op2"()
# CHECK: %0 = "custom.addi"
print(module)
# Trying to add a previously added op should raise.
try:
entry_block.operations.insert(0, op1)
except ValueError:
pass
else:
assert False, "expected insert of attached op to raise"
run(testOperationInsert)
# CHECK-LABEL: TEST: testOperationWithRegion
def testOperationWithRegion():
ctx = mlir.ir.Context()
ctx.allow_unregistered_dialects = True
loc = ctx.get_unknown_location()
i32 = mlir.ir.IntegerType.get_signed(ctx, 32)
op1 = ctx.create_operation("custom.op1", loc, regions=1)
block = op1.regions[0].blocks.append(i32, i32)
# CHECK: "custom.op1"() ( {
# CHECK: ^bb0(%arg0: si32, %arg1: si32): // no predecessors
# CHECK: "custom.terminator"() : () -> ()
# CHECK: }) : () -> ()
terminator = ctx.create_operation("custom.terminator", loc)
block.operations.insert(0, terminator)
print(op1)
# Now add the whole operation to another op.
# TODO: Verify lifetime hazard by nulling out the new owning module and
# accessing op1.
# TODO: Also verify accessing the terminator once both parents are nulled
# out.
module = ctx.parse_module(r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""")
func = module.operation.regions[0].blocks[0].operations[0]
entry_block = func.regions[0].blocks[0]
entry_block.operations.insert(0, op1)
# CHECK: func @f1
# CHECK: "custom.op1"()
# CHECK: "custom.terminator"
# CHECK: %0 = "custom.addi"
print(module)
run(testOperationWithRegion)