forked from OSchip/llvm-project
[mlir][Python] Custom python op view wrappers for building and traversing.
* Still rough edges that need more sugar but the bones are there. Notes left in the test case for things that can be improved. * Does not actually yield custom OpViews yet for traversing. Will rework that in a followup. Differential Revision: https://reviews.llvm.org/D89932
This commit is contained in:
parent
78ae1f6c90
commit
013b9322de
|
@ -4,6 +4,9 @@
|
|||
|
||||
set(PY_SRC_FILES
|
||||
mlir/__init__.py
|
||||
mlir/ir.py
|
||||
mlir/dialects/__init__.py
|
||||
mlir/dialects/std.py
|
||||
)
|
||||
|
||||
add_custom_target(MLIRBindingsPythonSources ALL
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
//===- Globals.h - MLIR Python extension globals --------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
|
||||
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
|
||||
/// Globals that are always accessible once the extension has been initialized.
|
||||
class PyGlobals {
|
||||
public:
|
||||
PyGlobals();
|
||||
~PyGlobals();
|
||||
|
||||
/// Most code should get the globals via this static accessor.
|
||||
static PyGlobals &get() {
|
||||
assert(instance && "PyGlobals is null");
|
||||
return *instance;
|
||||
}
|
||||
|
||||
/// Get and set the list of parent modules to search for dialect
|
||||
/// implementation classes.
|
||||
std::vector<std::string> &getDialectSearchPrefixes() {
|
||||
return dialectSearchPrefixes;
|
||||
}
|
||||
void setDialectSearchPrefixes(std::vector<std::string> newValues) {
|
||||
dialectSearchPrefixes.swap(newValues);
|
||||
}
|
||||
|
||||
/// Loads a python module corresponding to the given dialect namespace.
|
||||
/// No-ops if the module has already been loaded or is not found. Raises
|
||||
/// an error on any evaluation issues.
|
||||
/// Note that this returns void because it is expected that the module
|
||||
/// contains calls to decorators and helpers that register the salient
|
||||
/// entities.
|
||||
void loadDialectModule(const std::string &dialectNamespace);
|
||||
|
||||
/// Decorator for registering a custom Dialect class. The class object must
|
||||
/// have a DIALECT_NAMESPACE attribute.
|
||||
pybind11::object registerDialectDecorator(pybind11::object pyClass);
|
||||
|
||||
/// Adds a concrete implementation dialect class.
|
||||
/// Raises an exception if the mapping already exists.
|
||||
/// This is intended to be called by implementation code.
|
||||
void registerDialectImpl(const std::string &dialectNamespace,
|
||||
pybind11::object pyClass);
|
||||
|
||||
/// Adds a concrete implementation operation class.
|
||||
/// Raises an exception if the mapping already exists.
|
||||
/// This is intended to be called by implementation code.
|
||||
void registerOperationImpl(const std::string &operationName,
|
||||
pybind11::object pyClass,
|
||||
pybind11::object rawClass);
|
||||
|
||||
/// Looks up a registered dialect class by namespace. Note that this may
|
||||
/// trigger loading of the defining module and can arbitrarily re-enter.
|
||||
llvm::Optional<pybind11::object>
|
||||
lookupDialectClass(const std::string &dialectNamespace);
|
||||
|
||||
private:
|
||||
static PyGlobals *instance;
|
||||
/// Module name prefixes to search under for dialect implementation modules.
|
||||
std::vector<std::string> dialectSearchPrefixes;
|
||||
/// Map of dialect namespace to bool flag indicating whether the module has
|
||||
/// been successfully loaded or resolved to not found.
|
||||
llvm::StringSet<> loadedDialectModules;
|
||||
/// Map of dialect namespace to external dialect class object.
|
||||
llvm::StringMap<pybind11::object> dialectClassMap;
|
||||
/// Map of full operation name to external operation class object.
|
||||
llvm::StringMap<pybind11::object> operationClassMap;
|
||||
/// Map of operation name to custom subclass that directly initializes
|
||||
/// the OpView base class (bypassing the user class constructor).
|
||||
llvm::StringMap<pybind11::object> rawOperationClassMap;
|
||||
};
|
||||
|
||||
} // namespace python
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_BINDINGS_PYTHON_GLOBALS_H
|
|
@ -7,6 +7,8 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "IRModules.h"
|
||||
|
||||
#include "Globals.h"
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
|
@ -209,19 +211,27 @@ private:
|
|||
} // namespace
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Type-checking utilities.
|
||||
// Utilities.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
namespace {
|
||||
|
||||
/// Checks whether the given type is an integer or float type.
|
||||
int mlirTypeIsAIntegerOrFloat(MlirType type) {
|
||||
static int mlirTypeIsAIntegerOrFloat(MlirType type) {
|
||||
return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
|
||||
mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
static py::object
|
||||
createCustomDialectWrapper(const std::string &dialectNamespace,
|
||||
py::object dialectDescriptor) {
|
||||
auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
|
||||
if (!dialectClass) {
|
||||
// Use the base class.
|
||||
return py::cast(PyDialect(std::move(dialectDescriptor)));
|
||||
}
|
||||
|
||||
// Create the custom implementation.
|
||||
return (*dialectClass)(std::move(dialectDescriptor));
|
||||
}
|
||||
//------------------------------------------------------------------------------
|
||||
// Collections.
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -567,9 +577,11 @@ size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
|
|||
|
||||
py::object PyMlirContext::createOperation(
|
||||
std::string name, PyLocation location,
|
||||
llvm::Optional<std::vector<PyValue *>> operands,
|
||||
llvm::Optional<std::vector<PyType *>> results,
|
||||
llvm::Optional<py::dict> attributes,
|
||||
llvm::Optional<std::vector<PyBlock *>> successors, int regions) {
|
||||
llvm::SmallVector<MlirValue, 4> mlirOperands;
|
||||
llvm::SmallVector<MlirType, 4> mlirResults;
|
||||
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
|
||||
llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
|
||||
|
@ -578,6 +590,16 @@ py::object PyMlirContext::createOperation(
|
|||
if (regions < 0)
|
||||
throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
|
||||
|
||||
// Unpack/validate operands.
|
||||
if (operands) {
|
||||
mlirOperands.reserve(operands->size());
|
||||
for (PyValue *operand : *operands) {
|
||||
if (!operand)
|
||||
throw SetPyError(PyExc_ValueError, "operand value cannot be None");
|
||||
mlirOperands.push_back(operand->get());
|
||||
}
|
||||
}
|
||||
|
||||
// Unpack/validate results.
|
||||
if (results) {
|
||||
mlirResults.reserve(results->size());
|
||||
|
@ -614,6 +636,9 @@ py::object PyMlirContext::createOperation(
|
|||
// Apply unpacked/validated to the operation state. Beyond this
|
||||
// point, exceptions cannot be thrown or else the state will leak.
|
||||
MlirOperationState state = mlirOperationStateGet(name.c_str(), location.loc);
|
||||
if (!mlirOperands.empty())
|
||||
mlirOperationStateAddOperands(&state, mlirOperands.size(),
|
||||
mlirOperands.data());
|
||||
if (!mlirResults.empty())
|
||||
mlirOperationStateAddResults(&state, mlirResults.size(),
|
||||
mlirResults.data());
|
||||
|
@ -646,6 +671,24 @@ py::object PyMlirContext::createOperation(
|
|||
return PyOperation::createDetached(getRef(), operation).releaseObject();
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyDialect, PyDialectDescriptor, PyDialects
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
MlirDialect PyDialects::getDialectForKey(const std::string &key,
|
||||
bool attrError) {
|
||||
// If the "std" dialect was asked for, substitute the empty namespace :(
|
||||
static const std::string emptyKey;
|
||||
const std::string *canonKey = key == "std" ? &emptyKey : &key;
|
||||
MlirDialect dialect = mlirContextGetOrLoadDialect(
|
||||
getContext()->get(), {canonKey->data(), canonKey->size()});
|
||||
if (mlirDialectIsNull(dialect)) {
|
||||
throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
|
||||
llvm::Twine("Dialect '") + key + "' not found");
|
||||
}
|
||||
return dialect;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyModule
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -815,6 +858,45 @@ py::object PyOperation::getAsm(bool binary,
|
|||
return fileObject.attr("getvalue")();
|
||||
}
|
||||
|
||||
PyOpView::PyOpView(py::object operation)
|
||||
: operationObject(std::move(operation)),
|
||||
operation(py::cast<PyOperation *>(this->operationObject)) {}
|
||||
|
||||
py::object PyOpView::createRawSubclass(py::object userClass) {
|
||||
// This is... a little gross. The typical pattern is to have a pure python
|
||||
// class that extends OpView like:
|
||||
// class AddFOp(_cext.ir.OpView):
|
||||
// def __init__(self, loc, lhs, rhs):
|
||||
// operation = loc.context.create_operation(
|
||||
// "addf", lhs, rhs, results=[lhs.type])
|
||||
// super().__init__(operation)
|
||||
//
|
||||
// I.e. The goal of the user facing type is to provide a nice constructor
|
||||
// that has complete freedom for the op under construction. This is at odds
|
||||
// with our other desire to sometimes create this object by just passing an
|
||||
// operation (to initialize the base class). We could do *arg and **kwargs
|
||||
// munging to try to make it work, but instead, we synthesize a new class
|
||||
// on the fly which extends this user class (AddFOp in this example) and
|
||||
// *give it* the base class's __init__ method, thus bypassing the
|
||||
// intermediate subclass's __init__ method entirely. While slightly,
|
||||
// underhanded, this is safe/legal because the type hierarchy has not changed
|
||||
// (we just added a new leaf) and we aren't mucking around with __new__.
|
||||
// Typically, this new class will be stored on the original as "_Raw" and will
|
||||
// be used for casts and other things that need a variant of the class that
|
||||
// is initialized purely from an operation.
|
||||
py::object parentMetaclass =
|
||||
py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
|
||||
py::dict attributes;
|
||||
// TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
|
||||
// now.
|
||||
// auto opViewType = py::type::of<PyOpView>();
|
||||
auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
|
||||
attributes["__init__"] = opViewType.attr("__init__");
|
||||
py::str origName = userClass.attr("__name__");
|
||||
py::str newName = py::str("_") + origName;
|
||||
return parentMetaclass(newName, py::make_tuple(userClass), attributes);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyAttribute.
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -966,6 +1048,41 @@ private:
|
|||
MlirBlock block;
|
||||
};
|
||||
|
||||
/// A list of operation results. Internally, these are stored as consecutive
|
||||
/// elements, random access is cheap. The result list is associated with the
|
||||
/// operation whose results these are, and extends the lifetime of this
|
||||
/// operation.
|
||||
class PyOpOperandList {
|
||||
public:
|
||||
PyOpOperandList(PyOperationRef operation) : operation(operation) {}
|
||||
|
||||
/// Returns the length of the result list.
|
||||
intptr_t dunderLen() {
|
||||
operation->checkValid();
|
||||
return mlirOperationGetNumOperands(operation->get());
|
||||
}
|
||||
|
||||
/// Returns `index`-th element in the result list.
|
||||
PyOpResult dunderGetItem(intptr_t index) {
|
||||
if (index < 0 || index >= dunderLen()) {
|
||||
throw SetPyError(PyExc_IndexError,
|
||||
"attempt to access out of bounds region");
|
||||
}
|
||||
PyValue value(operation, mlirOperationGetOperand(operation->get(), index));
|
||||
return PyOpResult(value);
|
||||
}
|
||||
|
||||
/// Defines a Python class in the bindings.
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyOpOperandList>(m, "OpOperandList")
|
||||
.def("__len__", &PyOpOperandList::dunderLen)
|
||||
.def("__getitem__", &PyOpOperandList::dunderGetItem);
|
||||
}
|
||||
|
||||
private:
|
||||
PyOperationRef operation;
|
||||
};
|
||||
|
||||
/// A list of operation results. Internally, these are stored as consecutive
|
||||
/// elements, random access is cheap. The result list is associated with the
|
||||
/// operation whose results these are, and extends the lifetime of this
|
||||
|
@ -1914,7 +2031,9 @@ public:
|
|||
//------------------------------------------------------------------------------
|
||||
|
||||
void mlir::python::populateIRSubmodule(py::module &m) {
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of MlirContext
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyMlirContext>(m, "Context")
|
||||
.def(py::init<>(&PyMlirContext::createNewContextForInit))
|
||||
.def_static("_get_live_count", &PyMlirContext::getLiveCount)
|
||||
|
@ -1928,6 +2047,25 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
|
||||
&PyMlirContext::getCapsule)
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
|
||||
.def_property_readonly(
|
||||
"dialects",
|
||||
[](PyMlirContext &self) { return PyDialects(self.getRef()); },
|
||||
"Gets a container for accessing dialects by name")
|
||||
.def_property_readonly(
|
||||
"d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
|
||||
"Alias for 'dialect'")
|
||||
.def(
|
||||
"get_dialect_descriptor",
|
||||
[=](PyMlirContext &self, std::string &name) {
|
||||
MlirDialect dialect = mlirContextGetOrLoadDialect(
|
||||
self.get(), {name.data(), name.size()});
|
||||
if (mlirDialectIsNull(dialect)) {
|
||||
throw SetPyError(PyExc_ValueError,
|
||||
llvm::Twine("Dialect '") + name + "' not found");
|
||||
}
|
||||
return PyDialectDescriptor(self.getRef(), dialect);
|
||||
},
|
||||
"Gets or loads a dialect by name, returning its descriptor object")
|
||||
.def_property(
|
||||
"allow_unregistered_dialects",
|
||||
[](PyMlirContext &self) -> bool {
|
||||
|
@ -1937,8 +2075,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
mlirContextSetAllowUnregisteredDialects(self.get(), value);
|
||||
})
|
||||
.def("create_operation", &PyMlirContext::createOperation, py::arg("name"),
|
||||
py::arg("location"), py::arg("results") = py::none(),
|
||||
py::arg("attributes") = py::none(),
|
||||
py::arg("location"), py::arg("operands") = py::none(),
|
||||
py::arg("results") = py::none(), py::arg("attributes") = py::none(),
|
||||
py::arg("successors") = py::none(), py::arg("regions") = 0,
|
||||
kContextCreateOperationDocstring)
|
||||
.def(
|
||||
|
@ -2009,6 +2147,62 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
kContextGetFileLocationDocstring, py::arg("filename"),
|
||||
py::arg("line"), py::arg("col"));
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyDialectDescriptor
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
|
||||
.def_property_readonly("namespace",
|
||||
[](PyDialectDescriptor &self) {
|
||||
MlirStringRef ns =
|
||||
mlirDialectGetNamespace(self.get());
|
||||
return py::str(ns.data, ns.length);
|
||||
})
|
||||
.def("__repr__", [](PyDialectDescriptor &self) {
|
||||
MlirStringRef ns = mlirDialectGetNamespace(self.get());
|
||||
std::string repr("<DialectDescriptor ");
|
||||
repr.append(ns.data, ns.length);
|
||||
repr.append(">");
|
||||
return repr;
|
||||
});
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyDialects
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyDialects>(m, "Dialects")
|
||||
.def("__getitem__",
|
||||
[=](PyDialects &self, std::string keyName) {
|
||||
MlirDialect dialect =
|
||||
self.getDialectForKey(keyName, /*attrError=*/false);
|
||||
py::object descriptor =
|
||||
py::cast(PyDialectDescriptor{self.getContext(), dialect});
|
||||
return createCustomDialectWrapper(keyName, std::move(descriptor));
|
||||
})
|
||||
.def("__getattr__", [=](PyDialects &self, std::string attrName) {
|
||||
MlirDialect dialect =
|
||||
self.getDialectForKey(attrName, /*attrError=*/true);
|
||||
py::object descriptor =
|
||||
py::cast(PyDialectDescriptor{self.getContext(), dialect});
|
||||
return createCustomDialectWrapper(attrName, std::move(descriptor));
|
||||
});
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyDialect
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyDialect>(m, "Dialect")
|
||||
.def(py::init<py::object>(), "descriptor")
|
||||
.def_property_readonly(
|
||||
"descriptor", [](PyDialect &self) { return self.getDescriptor(); })
|
||||
.def("__repr__", [](py::object self) {
|
||||
auto clazz = self.attr("__class__");
|
||||
return py::str("<Dialect ") +
|
||||
self.attr("descriptor").attr("namespace") + py::str(" (class ") +
|
||||
clazz.attr("__module__") + py::str(".") +
|
||||
clazz.attr("__name__") + py::str(")>");
|
||||
});
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of Location
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyLocation>(m, "Location")
|
||||
.def_property_readonly(
|
||||
"context",
|
||||
|
@ -2021,7 +2215,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
return printAccum.join();
|
||||
});
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of Module
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyModule>(m, "Module")
|
||||
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
|
||||
|
@ -2055,12 +2251,17 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
},
|
||||
kOperationStrDunderDocstring);
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of Operation.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyOperation>(m, "Operation")
|
||||
.def_property_readonly(
|
||||
"context",
|
||||
[](PyOperation &self) { return self.getContext().getObject(); },
|
||||
"Context that owns the Operation")
|
||||
.def_property_readonly(
|
||||
"operands",
|
||||
[](PyOperation &self) { return PyOpOperandList(self.getRef()); })
|
||||
.def_property_readonly(
|
||||
"regions",
|
||||
[](PyOperation &self) { return PyRegionList(self.getRef()); })
|
||||
|
@ -2098,7 +2299,15 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
py::arg("print_generic_op_form") = false,
|
||||
py::arg("use_local_scope") = false, kOperationGetAsmDocstring);
|
||||
|
||||
py::class_<PyOpView>(m, "OpView")
|
||||
.def(py::init<py::object>())
|
||||
.def_property_readonly("operation", &PyOpView::getOperationObject)
|
||||
.def("__str__",
|
||||
[](PyOpView &self) { return py::str(self.getOperationObject()); });
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyRegion.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyRegion>(m, "Region")
|
||||
.def_property_readonly(
|
||||
"blocks",
|
||||
|
@ -2123,7 +2332,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
}
|
||||
});
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyBlock.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyBlock>(m, "Block")
|
||||
.def_property_readonly(
|
||||
"arguments",
|
||||
|
@ -2167,7 +2378,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
},
|
||||
"Returns the assembly form of the block.");
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyAttribute.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyAttribute>(m, "Attribute")
|
||||
.def_property_readonly(
|
||||
"context",
|
||||
|
@ -2219,6 +2432,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
return printAccum.join();
|
||||
});
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyNamedAttribute
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyNamedAttribute>(m, "NamedAttribute")
|
||||
.def("__repr__",
|
||||
[](PyNamedAttribute &self) {
|
||||
|
@ -2257,7 +2473,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
PyStringAttribute::bind(m);
|
||||
PyDenseElementsAttribute::bind(m);
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyType.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyType>(m, "Type")
|
||||
.def_property_readonly(
|
||||
"context", [](PyType &self) { return self.getContext().getObject(); },
|
||||
|
@ -2313,7 +2531,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
PyTupleType::bind(m);
|
||||
PyFunctionType::bind(m);
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of Value.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyValue>(m, "Value")
|
||||
.def_property_readonly(
|
||||
"context",
|
||||
|
@ -2346,6 +2566,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
PyBlockList::bind(m);
|
||||
PyOperationIterator::bind(m);
|
||||
PyOperationList::bind(m);
|
||||
PyOpOperandList::bind(m);
|
||||
PyOpResultList::bind(m);
|
||||
PyRegionIterator::bind(m);
|
||||
PyRegionList::bind(m);
|
||||
|
|
|
@ -132,6 +132,7 @@ public:
|
|||
/// Creates an operation. See corresponding python docstring.
|
||||
pybind11::object
|
||||
createOperation(std::string name, PyLocation location,
|
||||
llvm::Optional<std::vector<PyValue *>> operands,
|
||||
llvm::Optional<std::vector<PyType *>> results,
|
||||
llvm::Optional<pybind11::dict> attributes,
|
||||
llvm::Optional<std::vector<PyBlock *>> successors,
|
||||
|
@ -187,6 +188,45 @@ private:
|
|||
PyMlirContextRef contextRef;
|
||||
};
|
||||
|
||||
/// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
|
||||
/// order to differentiate it from the `Dialect` base class which is extended by
|
||||
/// plugins which extend dialect functionality through extension python code.
|
||||
/// This should be seen as the "low-level" object and `Dialect` as the
|
||||
/// high-level, user facing object.
|
||||
class PyDialectDescriptor : public BaseContextObject {
|
||||
public:
|
||||
PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
|
||||
: BaseContextObject(std::move(contextRef)), dialect(dialect) {}
|
||||
|
||||
MlirDialect get() { return dialect; }
|
||||
|
||||
private:
|
||||
MlirDialect dialect;
|
||||
};
|
||||
|
||||
/// User-level object for accessing dialects with dotted syntax such as:
|
||||
/// ctx.dialect.std
|
||||
class PyDialects : public BaseContextObject {
|
||||
public:
|
||||
PyDialects(PyMlirContextRef contextRef)
|
||||
: BaseContextObject(std::move(contextRef)) {}
|
||||
|
||||
MlirDialect getDialectForKey(const std::string &key, bool attrError);
|
||||
};
|
||||
|
||||
/// User-level dialect object. For dialects that have a registered extension,
|
||||
/// this will be the base class of the extension dialect type. For un-extended,
|
||||
/// objects of this type will be returned directly.
|
||||
class PyDialect {
|
||||
public:
|
||||
PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}
|
||||
|
||||
pybind11::object getDescriptor() { return descriptor; }
|
||||
|
||||
private:
|
||||
pybind11::object descriptor;
|
||||
};
|
||||
|
||||
/// Wrapper around an MlirLocation.
|
||||
class PyLocation : public BaseContextObject {
|
||||
public:
|
||||
|
@ -305,6 +345,24 @@ private:
|
|||
bool valid = true;
|
||||
};
|
||||
|
||||
/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
|
||||
/// providing more instance-specific accessors and serve as the base class for
|
||||
/// custom ODS-style operation classes. Since this class is subclass on the
|
||||
/// python side, it must present an __init__ method that operates in pure
|
||||
/// python types.
|
||||
class PyOpView {
|
||||
public:
|
||||
PyOpView(pybind11::object operation);
|
||||
|
||||
static pybind11::object createRawSubclass(pybind11::object userClass);
|
||||
|
||||
pybind11::object getOperationObject() { return operationObject; }
|
||||
|
||||
private:
|
||||
pybind11::object operationObject; // Holds the reference.
|
||||
PyOperation *operation; // For efficient, cast-free access from C++
|
||||
};
|
||||
|
||||
/// Wrapper around an MlirRegion.
|
||||
/// Regions are managed completely by their containing operation. Unlike the
|
||||
/// C++ API, the python API does not support detached regions.
|
||||
|
|
|
@ -8,17 +8,155 @@
|
|||
|
||||
#include <tuple>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "Globals.h"
|
||||
#include "IRModules.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mlir;
|
||||
using namespace mlir::python;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// PyGlobals
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
PyGlobals *PyGlobals::instance = nullptr;
|
||||
|
||||
PyGlobals::PyGlobals() {
|
||||
assert(!instance && "PyGlobals already constructed");
|
||||
instance = this;
|
||||
}
|
||||
|
||||
PyGlobals::~PyGlobals() { instance = nullptr; }
|
||||
|
||||
void PyGlobals::loadDialectModule(const std::string &dialectNamespace) {
|
||||
if (loadedDialectModules.contains(dialectNamespace))
|
||||
return;
|
||||
// Since re-entrancy is possible, make a copy of the search prefixes.
|
||||
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
|
||||
py::object loaded;
|
||||
for (std::string moduleName : localSearchPrefixes) {
|
||||
moduleName.push_back('.');
|
||||
moduleName.append(dialectNamespace);
|
||||
|
||||
try {
|
||||
loaded = py::module::import(moduleName.c_str());
|
||||
} catch (py::error_already_set &e) {
|
||||
if (e.matches(PyExc_ModuleNotFoundError)) {
|
||||
continue;
|
||||
} else {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
|
||||
// may have occurred, which may do anything.
|
||||
loadedDialectModules.insert(dialectNamespace);
|
||||
}
|
||||
|
||||
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
|
||||
py::object pyClass) {
|
||||
py::object &found = dialectClassMap[dialectNamespace];
|
||||
if (found) {
|
||||
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
|
||||
dialectNamespace +
|
||||
"' is already registered.");
|
||||
}
|
||||
found = std::move(pyClass);
|
||||
}
|
||||
|
||||
void PyGlobals::registerOperationImpl(const std::string &operationName,
|
||||
py::object pyClass, py::object rawClass) {
|
||||
py::object &found = operationClassMap[operationName];
|
||||
if (found) {
|
||||
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
|
||||
operationName +
|
||||
"' is already registered.");
|
||||
}
|
||||
found = std::move(pyClass);
|
||||
rawOperationClassMap[operationName] = std::move(rawClass);
|
||||
}
|
||||
|
||||
llvm::Optional<py::object>
|
||||
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
|
||||
loadDialectModule(dialectNamespace);
|
||||
// Fast match against the class map first (common case).
|
||||
const auto foundIt = dialectClassMap.find(dialectNamespace);
|
||||
if (foundIt != dialectClassMap.end()) {
|
||||
if (foundIt->second.is_none())
|
||||
return llvm::None;
|
||||
assert(foundIt->second && "py::object is defined");
|
||||
return foundIt->second;
|
||||
}
|
||||
|
||||
// Not found and loading did not yield a registration. Negative cache.
|
||||
dialectClassMap[dialectNamespace] = py::none();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Module initialization.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
PYBIND11_MODULE(_mlir, m) {
|
||||
m.doc() = "MLIR Python Native Extension";
|
||||
|
||||
py::class_<PyGlobals>(m, "_Globals")
|
||||
.def_property("dialect_search_modules",
|
||||
&PyGlobals::getDialectSearchPrefixes,
|
||||
&PyGlobals::setDialectSearchPrefixes)
|
||||
.def("append_dialect_search_prefix",
|
||||
[](PyGlobals &self, std::string moduleName) {
|
||||
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
|
||||
})
|
||||
.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
|
||||
"Testing hook for directly registering a dialect")
|
||||
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
|
||||
"Testing hook for directly registering an operation");
|
||||
|
||||
// Aside from making the globals accessible to python, having python manage
|
||||
// it is necessary to make sure it is destroyed (and releases its python
|
||||
// resources) properly.
|
||||
m.attr("globals") =
|
||||
py::cast(new PyGlobals, py::return_value_policy::take_ownership);
|
||||
|
||||
// Registration decorators.
|
||||
m.def(
|
||||
"register_dialect",
|
||||
[](py::object pyClass) {
|
||||
std::string dialectNamespace =
|
||||
pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
|
||||
PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
|
||||
return pyClass;
|
||||
},
|
||||
"Class decorator for registering a custom Dialect wrapper");
|
||||
m.def(
|
||||
"register_operation",
|
||||
[](py::object dialectClass) -> py::cpp_function {
|
||||
return py::cpp_function(
|
||||
[dialectClass](py::object opClass) -> py::object {
|
||||
std::string operationName =
|
||||
opClass.attr("OPERATION_NAME").cast<std::string>();
|
||||
auto rawSubclass = PyOpView::createRawSubclass(opClass);
|
||||
PyGlobals::get().registerOperationImpl(operationName, opClass,
|
||||
rawSubclass);
|
||||
|
||||
// Dict-stuff the new opClass by name onto the dialect class.
|
||||
py::object opClassName = opClass.attr("__name__");
|
||||
dialectClass.attr(opClassName) = opClass;
|
||||
|
||||
// Now create a special "Raw" subclass that passes through
|
||||
// construction to the OpView parent (bypasses the intermediate
|
||||
// child's __init__).
|
||||
opClass.attr("_Raw") = rawSubclass;
|
||||
return opClass;
|
||||
});
|
||||
},
|
||||
"Class decorator for registering a custom Operation wrapper");
|
||||
|
||||
// Define and populate IR submodule.
|
||||
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
|
||||
populateIRSubmodule(irModule);
|
||||
|
|
|
@ -8,4 +8,37 @@
|
|||
# and arbitrate any one-time initialization needed in various shared-library
|
||||
# scenarios.
|
||||
|
||||
from _mlir import *
|
||||
__all__ = [
|
||||
"ir",
|
||||
]
|
||||
|
||||
# Expose the corresponding C-Extension module with a well-known name at this
|
||||
# top-level module. This allows relative imports like the following to
|
||||
# function:
|
||||
# from .. import _cext
|
||||
# This reduces coupling, allowing embedding of the python sources into another
|
||||
# project that can just vary based on this top-level loader module.
|
||||
import _mlir as _cext
|
||||
|
||||
def _reexport_cext(cext_module_name, target_module_name):
|
||||
"""Re-exports a named sub-module of the C-Extension into another module.
|
||||
|
||||
Typically:
|
||||
from . import _reexport_cext
|
||||
_reexport_cext("ir", __name__)
|
||||
del _reexport_cext
|
||||
"""
|
||||
import sys
|
||||
target_module = sys.modules[target_module_name]
|
||||
source_module = getattr(_cext, cext_module_name)
|
||||
for attr_name in dir(source_module):
|
||||
if not attr_name.startswith("__"):
|
||||
setattr(target_module, attr_name, getattr(source_module, attr_name))
|
||||
|
||||
|
||||
# Import sub-modules. Since these may import from here, this must come after
|
||||
# any exported definitions.
|
||||
from . import ir
|
||||
|
||||
# Add our 'dialects' parent module to the search path for implementations.
|
||||
_cext.globals.append_dialect_search_prefix("mlir.dialects")
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# Re-export the parent _cext so that every level of the API can get it locally.
|
||||
from .. import _cext
|
|
@ -0,0 +1,33 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# TODO: This file should be auto-generated.
|
||||
|
||||
from . import _cext
|
||||
|
||||
@_cext.register_dialect
|
||||
class _Dialect(_cext.ir.Dialect):
|
||||
# Special case: 'std' namespace aliases to the empty namespace.
|
||||
DIALECT_NAMESPACE = "std"
|
||||
pass
|
||||
|
||||
@_cext.register_operation(_Dialect)
|
||||
class AddFOp(_cext.ir.OpView):
|
||||
OPERATION_NAME = "std.addf"
|
||||
|
||||
def __init__(self, loc, lhs, rhs):
|
||||
super().__init__(loc.context.create_operation(
|
||||
"std.addf", loc, operands=[lhs, rhs], results=[lhs.type]))
|
||||
|
||||
@property
|
||||
def lhs(self):
|
||||
return self.operation.operands[0]
|
||||
|
||||
@property
|
||||
def rhs(self):
|
||||
return self.operation.operands[1]
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
return self.operation.results[0]
|
|
@ -0,0 +1,8 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# Simply a wrapper around the extension module of the same name.
|
||||
from . import _reexport_cext
|
||||
_reexport_cext("ir", __name__)
|
||||
del _reexport_cext
|
|
@ -0,0 +1,107 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import gc
|
||||
import mlir
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
gc.collect()
|
||||
assert mlir.ir.Context._get_live_count() == 0
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDialectDescriptor
|
||||
def testDialectDescriptor():
|
||||
ctx = mlir.ir.Context()
|
||||
d = ctx.get_dialect_descriptor("std")
|
||||
# CHECK: <DialectDescriptor std>
|
||||
print(d)
|
||||
# CHECK: std
|
||||
print(d.namespace)
|
||||
try:
|
||||
_ = ctx.get_dialect_descriptor("not_existing")
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
assert False, "Expected exception"
|
||||
|
||||
run(testDialectDescriptor)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testUserDialectClass
|
||||
def testUserDialectClass():
|
||||
ctx = mlir.ir.Context()
|
||||
# Access using attribute.
|
||||
d = ctx.dialects.std
|
||||
# Note that the standard dialect namespace prints as ''. Others will print
|
||||
# as "<Dialect %namespace (..."
|
||||
# CHECK: <Dialect (class mlir.dialects.std._Dialect)>
|
||||
print(d)
|
||||
try:
|
||||
_ = ctx.dialects.not_existing
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
assert False, "Expected exception"
|
||||
|
||||
# Access using index.
|
||||
d = ctx.dialects["std"]
|
||||
# CHECK: <Dialect (class mlir.dialects.std._Dialect)>
|
||||
print(d)
|
||||
try:
|
||||
_ = ctx.dialects["not_existing"]
|
||||
except IndexError:
|
||||
pass
|
||||
else:
|
||||
assert False, "Expected exception"
|
||||
|
||||
# Using the 'd' alias.
|
||||
d = ctx.d["std"]
|
||||
# CHECK: <Dialect (class mlir.dialects.std._Dialect)>
|
||||
print(d)
|
||||
|
||||
run(testUserDialectClass)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testCustomOpView
|
||||
# This test uses the standard dialect AddFOp as an example of a user op.
|
||||
# TODO: Op creation and access is still quite verbose: simplify this test as
|
||||
# additional capabilities come online.
|
||||
def testCustomOpView():
|
||||
ctx = mlir.ir.Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
f32 = mlir.ir.F32Type.get(ctx)
|
||||
loc = ctx.get_unknown_location()
|
||||
m = ctx.create_module(loc)
|
||||
m_block = m.operation.regions[0].blocks[0]
|
||||
# TODO: Remove integer insertion in favor of InsertionPoint and/or op-based.
|
||||
ip = [0]
|
||||
def createInput():
|
||||
op = ctx.create_operation("pytest_dummy.intinput", loc, results=[f32])
|
||||
m_block.operations.insert(ip[0], op)
|
||||
ip[0] += 1
|
||||
# TODO: Auto result cast from operation
|
||||
return op.results[0]
|
||||
|
||||
# Create via dialects context collection.
|
||||
input1 = createInput()
|
||||
input2 = createInput()
|
||||
op1 = ctx.dialects.std.AddFOp(loc, input1, input2)
|
||||
# TODO: Auto operation cast from OpView
|
||||
# TODO: Context manager insertion point
|
||||
m_block.operations.insert(ip[0], op1.operation)
|
||||
ip[0] += 1
|
||||
|
||||
# Create via an import
|
||||
from mlir.dialects.std import AddFOp
|
||||
op2 = AddFOp(loc, input1, op1.result)
|
||||
m_block.operations.insert(ip[0], op2.operation)
|
||||
ip[0] += 1
|
||||
|
||||
# CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
|
||||
# CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
|
||||
# CHECK: %[[R0:.*]] = addf %[[INPUT0]], %[[INPUT1]] : f32
|
||||
# CHECK: %[[R1:.*]] = addf %[[INPUT0]], %[[R0]] : f32
|
||||
m.operation.print()
|
||||
|
||||
run(testCustomOpView)
|
Loading…
Reference in New Issue