diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index 8eab7dab1675..3fad701d1641 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -24,6 +24,22 @@ using llvm::SmallVector; // Docstrings (trivial, non-duplicated docstrings are included inline). //------------------------------------------------------------------------------ +static const char kContextCreateOperationDocstring[] = + R"(Creates a new operation. + +Args: + name: Operation name (e.g. "dialect.operation"). + location: A Location object. + results: Sequence of Type representing op result types. + attributes: Dict of str:Attribute. + successors: List of Block for the operation's successors. + regions: Number of regions to create. + +Returns: + A new "detached" Operation object. Detached operations can be added + to blocks, which causes them to become "attached." +)"; + static const char kContextParseDocstring[] = R"(Parses a module's assembly format from a string. @@ -60,6 +76,13 @@ static const char kTypeStrDunderDocstring[] = static const char kDumpDocstring[] = R"(Dumps a debug representation of the object to stderr.)"; +static const char kAppendBlockDocstring[] = + R"(Appends a new block, with argument types as positional args. + +Returns: + The created block. +)"; + //------------------------------------------------------------------------------ // Conversion utilities. //------------------------------------------------------------------------------ @@ -265,11 +288,25 @@ public: throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); } + PyBlock appendBlock(py::args pyArgTypes) { + operation->checkValid(); + llvm::SmallVector argTypes; + argTypes.reserve(pyArgTypes.size()); + for (auto &pyArg : pyArgTypes) { + argTypes.push_back(pyArg.cast().type); + } + + MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); + mlirRegionAppendOwnedBlock(region, block); + return PyBlock(operation, block); + } + static void bind(py::module &m) { py::class_(m, "BlockList") .def("__getitem__", &PyBlockList::dunderGetItem) .def("__iter__", &PyBlockList::dunderIter) - .def("__len__", &PyBlockList::dunderLen); + .def("__len__", &PyBlockList::dunderLen) + .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); } private: @@ -352,11 +389,41 @@ public: "attempt to access out of bounds operation"); } + void insert(int index, PyOperation &newOperation) { + parentOperation->checkValid(); + newOperation.checkValid(); + if (index < 0) { + throw SetPyError( + PyExc_IndexError, + "only positive insertion indices are supported for operations"); + } + if (newOperation.isAttached()) { + throw SetPyError( + PyExc_ValueError, + "attempt to insert an operation that has already been inserted"); + } + // TODO: Needing to do this check is unfortunate, especially since it will + // be a forward-scan, just like the following call to + // mlirBlockInsertOwnedOperation. Switch to insert before/after once + // D88148 lands. + if (index > dunderLen()) { + throw SetPyError(PyExc_IndexError, + "attempt to insert operation past end"); + } + mlirBlockInsertOwnedOperation(block, index, newOperation.get()); + newOperation.setAttached(); + // TODO: Rework the parentKeepAlive so as to avoid ownership hazards under + // the new ownership. + } + static void bind(py::module &m) { py::class_(m, "OperationList") .def("__getitem__", &PyOperationList::dunderGetItem) .def("__iter__", &PyOperationList::dunderIter) - .def("__len__", &PyOperationList::dunderLen); + .def("__len__", &PyOperationList::dunderLen) + .def("insert", &PyOperationList::insert, py::arg("index"), + py::arg("operation"), + "Inserts an operation at an indexed position"); } private: @@ -416,6 +483,87 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } +py::object PyMlirContext::createOperation( + std::string name, PyLocation location, + llvm::Optional> results, + llvm::Optional attributes, + llvm::Optional> successors, int regions) { + llvm::SmallVector mlirResults; + llvm::SmallVector mlirSuccessors; + llvm::SmallVector, 4> mlirAttributes; + + // General parameter validation. + if (regions < 0) + throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); + + // Unpack/validate results. + if (results) { + mlirResults.reserve(results->size()); + for (PyType *result : *results) { + // TODO: Verify result type originate from the same context. + if (!result) + throw SetPyError(PyExc_ValueError, "result type cannot be None"); + mlirResults.push_back(result->type); + } + } + // Unpack/validate attributes. + if (attributes) { + mlirAttributes.reserve(attributes->size()); + for (auto &it : *attributes) { + + auto name = it.first.cast(); + auto &attribute = it.second.cast(); + // TODO: Verify attribute originates from the same context. + mlirAttributes.emplace_back(std::move(name), attribute.attr); + } + } + // Unpack/validate successors. + if (successors) { + llvm::SmallVector mlirSuccessors; + mlirSuccessors.reserve(successors->size()); + for (auto *successor : *successors) { + // TODO: Verify successor originate from the same context. + if (!successor) + throw SetPyError(PyExc_ValueError, "successor block cannot be None"); + mlirSuccessors.push_back(successor->get()); + } + } + + // Apply unpacked/validated to the operation state. Beyond this + // point, exceptions cannot be thrown or else the state will leak. + MlirOperationState state = mlirOperationStateGet(name.c_str(), location.loc); + if (!mlirResults.empty()) + mlirOperationStateAddResults(&state, mlirResults.size(), + mlirResults.data()); + if (!mlirAttributes.empty()) { + // Note that the attribute names directly reference bytes in + // mlirAttributes, so that vector must not be changed from here + // on. + llvm::SmallVector mlirNamedAttributes; + mlirNamedAttributes.reserve(mlirAttributes.size()); + for (auto &it : mlirAttributes) + mlirNamedAttributes.push_back( + mlirNamedAttributeGet(it.first.c_str(), it.second)); + mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), + mlirNamedAttributes.data()); + } + if (!mlirSuccessors.empty()) + mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), + mlirSuccessors.data()); + if (regions) { + llvm::SmallVector mlirRegions; + mlirRegions.resize(regions); + for (int i = 0; i < regions; ++i) + mlirRegions[i] = mlirRegionCreate(); + mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), + mlirRegions.data()); + } + + // Construct the operation. + MlirOperation operation = mlirOperationCreate(&state); + return PyOperation::createDetached(getRef(), operation).releaseObject(); +} + //------------------------------------------------------------------------------ // PyModule //------------------------------------------------------------------------------ @@ -1153,6 +1301,11 @@ void mlir::python::populateIRSubmodule(py::module &m) { [](PyMlirContext &self, bool value) { mlirContextSetAllowUnregisteredDialects(self.get(), value); }) + .def("create_operation", &PyMlirContext::createOperation, py::arg("name"), + py::arg("location"), py::arg("results") = py::none(), + py::arg("attributes") = py::none(), + py::arg("successors") = py::none(), py::arg("regions") = 0, + kContextCreateOperationDocstring) .def( "parse_module", [](PyMlirContext &self, const std::string moduleAsm) { diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h index 06b697cfd786..41b18d216026 100644 --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -17,9 +17,12 @@ namespace mlir { namespace python { +class PyBlock; +class PyLocation; class PyMlirContext; class PyModule; class PyOperation; +class PyType; /// Template for a reference to a concrete type which captures a python /// reference to its underlying python object. @@ -112,6 +115,14 @@ public: /// Used for testing. size_t getLiveOperationCount(); + /// Creates an operation. See corresponding python docstring. + pybind11::object + createOperation(std::string name, PyLocation location, + llvm::Optional> results, + llvm::Optional attributes, + llvm::Optional> successors, + int regions); + private: PyMlirContext(MlirContext context); // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, @@ -227,6 +238,10 @@ public: } bool isAttached() { return attached; } + void setAttached() { + assert(!attached && "operation already attached"); + attached = true; + } void checkValid(); private: diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 29ab06a25055..0435aa461809 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -12,8 +12,16 @@ #include #include +#include "llvm/ADT/Optional.h" #include "llvm/ADT/Twine.h" +namespace pybind11 { +namespace detail { +template +struct type_caster> : optional_caster> {}; +} // namespace detail +} // namespace pybind11 + namespace mlir { namespace python { diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py index 9522e4b1ad98..881398e1eba3 100644 --- a/mlir/test/Bindings/Python/ir_operation.py +++ b/mlir/test/Bindings/Python/ir_operation.py @@ -99,3 +99,99 @@ def testTraverseOpRegionBlockIndices(): walk_operations("", module.operation) run(testTraverseOpRegionBlockIndices) + + +# CHECK-LABEL: TEST: testDetachedOperation +def testDetachedOperation(): + ctx = mlir.ir.Context() + ctx.allow_unregistered_dialects = True + loc = ctx.get_unknown_location() + i32 = mlir.ir.IntegerType.get_signed(ctx, 32) + op1 = ctx.create_operation( + "custom.op1", loc, results=[i32, i32], regions=1, attributes={ + "foo": mlir.ir.StringAttr.get(ctx, "foo_value"), + "bar": mlir.ir.StringAttr.get(ctx, "bar_value"), + }) + # CHECK: %0:2 = "custom.op1"() ( { + # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32) + print(op1) + + # TODO: Check successors once enough infra exists to do it properly. + +run(testDetachedOperation) + + +# CHECK-LABEL: TEST: testOperationInsert +def testOperationInsert(): + ctx = mlir.ir.Context() + ctx.allow_unregistered_dialects = True + module = ctx.parse_module(r""" + func @f1(%arg0: i32) -> i32 { + %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 + return %1 : i32 + } + """) + + # Create test op. + loc = ctx.get_unknown_location() + op1 = ctx.create_operation("custom.op1", loc) + op2 = ctx.create_operation("custom.op2", loc) + + func = module.operation.regions[0].blocks[0].operations[0] + entry_block = func.regions[0].blocks[0] + entry_block.operations.insert(0, op1) + entry_block.operations.insert(1, op2) + # CHECK: func @f1 + # CHECK: "custom.op1"() + # CHECK: "custom.op2"() + # CHECK: %0 = "custom.addi" + print(module) + + # Trying to add a previously added op should raise. + try: + entry_block.operations.insert(0, op1) + except ValueError: + pass + else: + assert False, "expected insert of attached op to raise" + +run(testOperationInsert) + + +# CHECK-LABEL: TEST: testOperationWithRegion +def testOperationWithRegion(): + ctx = mlir.ir.Context() + ctx.allow_unregistered_dialects = True + loc = ctx.get_unknown_location() + i32 = mlir.ir.IntegerType.get_signed(ctx, 32) + op1 = ctx.create_operation("custom.op1", loc, regions=1) + block = op1.regions[0].blocks.append(i32, i32) + # CHECK: "custom.op1"() ( { + # CHECK: ^bb0(%arg0: si32, %arg1: si32): // no predecessors + # CHECK: "custom.terminator"() : () -> () + # CHECK: }) : () -> () + terminator = ctx.create_operation("custom.terminator", loc) + block.operations.insert(0, terminator) + print(op1) + + # Now add the whole operation to another op. + # TODO: Verify lifetime hazard by nulling out the new owning module and + # accessing op1. + # TODO: Also verify accessing the terminator once both parents are nulled + # out. + module = ctx.parse_module(r""" + func @f1(%arg0: i32) -> i32 { + %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 + return %1 : i32 + } + """) + func = module.operation.regions[0].blocks[0].operations[0] + entry_block = func.regions[0].blocks[0] + entry_block.operations.insert(0, op1) + # CHECK: func @f1 + # CHECK: "custom.op1"() + # CHECK: "custom.terminator" + # CHECK: %0 = "custom.addi" + print(module) + +run(testOperationWithRegion)