[mlir][Python] Custom python op view wrappers for building and traversing.

* Still rough edges that need more sugar but the bones are there. Notes left in the test case for things that can be improved.
* Does not actually yield custom OpViews yet for traversing. Will rework that in a followup.

Differential Revision: https://reviews.llvm.org/D89932
This commit is contained in:
Stella Laurenzo 2020-10-21 23:34:01 -07:00
parent 78ae1f6c90
commit 013b9322de
10 changed files with 710 additions and 9 deletions

View File

@ -4,6 +4,9 @@
set(PY_SRC_FILES
mlir/__init__.py
mlir/ir.py
mlir/dialects/__init__.py
mlir/dialects/std.py
)
add_custom_target(MLIRBindingsPythonSources ALL

View File

@ -0,0 +1,94 @@
//===- Globals.h - MLIR Python extension globals --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
#include <string>
#include <vector>
#include "PybindUtils.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
namespace mlir {
namespace python {
/// Globals that are always accessible once the extension has been initialized.
class PyGlobals {
public:
PyGlobals();
~PyGlobals();
/// Most code should get the globals via this static accessor.
static PyGlobals &get() {
assert(instance && "PyGlobals is null");
return *instance;
}
/// Get and set the list of parent modules to search for dialect
/// implementation classes.
std::vector<std::string> &getDialectSearchPrefixes() {
return dialectSearchPrefixes;
}
void setDialectSearchPrefixes(std::vector<std::string> newValues) {
dialectSearchPrefixes.swap(newValues);
}
/// Loads a python module corresponding to the given dialect namespace.
/// No-ops if the module has already been loaded or is not found. Raises
/// an error on any evaluation issues.
/// Note that this returns void because it is expected that the module
/// contains calls to decorators and helpers that register the salient
/// entities.
void loadDialectModule(const std::string &dialectNamespace);
/// Decorator for registering a custom Dialect class. The class object must
/// have a DIALECT_NAMESPACE attribute.
pybind11::object registerDialectDecorator(pybind11::object pyClass);
/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerDialectImpl(const std::string &dialectNamespace,
pybind11::object pyClass);
/// Adds a concrete implementation operation class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
pybind11::object pyClass,
pybind11::object rawClass);
/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
llvm::Optional<pybind11::object>
lookupDialectClass(const std::string &dialectNamespace);
private:
static PyGlobals *instance;
/// Module name prefixes to search under for dialect implementation modules.
std::vector<std::string> dialectSearchPrefixes;
/// Map of dialect namespace to bool flag indicating whether the module has
/// been successfully loaded or resolved to not found.
llvm::StringSet<> loadedDialectModules;
/// Map of dialect namespace to external dialect class object.
llvm::StringMap<pybind11::object> dialectClassMap;
/// Map of full operation name to external operation class object.
llvm::StringMap<pybind11::object> operationClassMap;
/// Map of operation name to custom subclass that directly initializes
/// the OpView base class (bypassing the user class constructor).
llvm::StringMap<pybind11::object> rawOperationClassMap;
};
} // namespace python
} // namespace mlir
#endif // MLIR_BINDINGS_PYTHON_GLOBALS_H

View File

@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "IRModules.h"
#include "Globals.h"
#include "PybindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h"
@ -209,19 +211,27 @@ private:
} // namespace
//------------------------------------------------------------------------------
// Type-checking utilities.
// Utilities.
//------------------------------------------------------------------------------
namespace {
/// Checks whether the given type is an integer or float type.
int mlirTypeIsAIntegerOrFloat(MlirType type) {
static int mlirTypeIsAIntegerOrFloat(MlirType type) {
return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
}
} // namespace
static py::object
createCustomDialectWrapper(const std::string &dialectNamespace,
py::object dialectDescriptor) {
auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
if (!dialectClass) {
// Use the base class.
return py::cast(PyDialect(std::move(dialectDescriptor)));
}
// Create the custom implementation.
return (*dialectClass)(std::move(dialectDescriptor));
}
//------------------------------------------------------------------------------
// Collections.
//------------------------------------------------------------------------------
@ -567,9 +577,11 @@ size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
py::object PyMlirContext::createOperation(
std::string name, PyLocation location,
llvm::Optional<std::vector<PyValue *>> operands,
llvm::Optional<std::vector<PyType *>> results,
llvm::Optional<py::dict> attributes,
llvm::Optional<std::vector<PyBlock *>> successors, int regions) {
llvm::SmallVector<MlirValue, 4> mlirOperands;
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
@ -578,6 +590,16 @@ py::object PyMlirContext::createOperation(
if (regions < 0)
throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
// Unpack/validate operands.
if (operands) {
mlirOperands.reserve(operands->size());
for (PyValue *operand : *operands) {
if (!operand)
throw SetPyError(PyExc_ValueError, "operand value cannot be None");
mlirOperands.push_back(operand->get());
}
}
// Unpack/validate results.
if (results) {
mlirResults.reserve(results->size());
@ -614,6 +636,9 @@ py::object PyMlirContext::createOperation(
// Apply unpacked/validated to the operation state. Beyond this
// point, exceptions cannot be thrown or else the state will leak.
MlirOperationState state = mlirOperationStateGet(name.c_str(), location.loc);
if (!mlirOperands.empty())
mlirOperationStateAddOperands(&state, mlirOperands.size(),
mlirOperands.data());
if (!mlirResults.empty())
mlirOperationStateAddResults(&state, mlirResults.size(),
mlirResults.data());
@ -646,6 +671,24 @@ py::object PyMlirContext::createOperation(
return PyOperation::createDetached(getRef(), operation).releaseObject();
}
//------------------------------------------------------------------------------
// PyDialect, PyDialectDescriptor, PyDialects
//------------------------------------------------------------------------------
MlirDialect PyDialects::getDialectForKey(const std::string &key,
bool attrError) {
// If the "std" dialect was asked for, substitute the empty namespace :(
static const std::string emptyKey;
const std::string *canonKey = key == "std" ? &emptyKey : &key;
MlirDialect dialect = mlirContextGetOrLoadDialect(
getContext()->get(), {canonKey->data(), canonKey->size()});
if (mlirDialectIsNull(dialect)) {
throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
llvm::Twine("Dialect '") + key + "' not found");
}
return dialect;
}
//------------------------------------------------------------------------------
// PyModule
//------------------------------------------------------------------------------
@ -815,6 +858,45 @@ py::object PyOperation::getAsm(bool binary,
return fileObject.attr("getvalue")();
}
PyOpView::PyOpView(py::object operation)
: operationObject(std::move(operation)),
operation(py::cast<PyOperation *>(this->operationObject)) {}
py::object PyOpView::createRawSubclass(py::object userClass) {
// This is... a little gross. The typical pattern is to have a pure python
// class that extends OpView like:
// class AddFOp(_cext.ir.OpView):
// def __init__(self, loc, lhs, rhs):
// operation = loc.context.create_operation(
// "addf", lhs, rhs, results=[lhs.type])
// super().__init__(operation)
//
// I.e. The goal of the user facing type is to provide a nice constructor
// that has complete freedom for the op under construction. This is at odds
// with our other desire to sometimes create this object by just passing an
// operation (to initialize the base class). We could do *arg and **kwargs
// munging to try to make it work, but instead, we synthesize a new class
// on the fly which extends this user class (AddFOp in this example) and
// *give it* the base class's __init__ method, thus bypassing the
// intermediate subclass's __init__ method entirely. While slightly,
// underhanded, this is safe/legal because the type hierarchy has not changed
// (we just added a new leaf) and we aren't mucking around with __new__.
// Typically, this new class will be stored on the original as "_Raw" and will
// be used for casts and other things that need a variant of the class that
// is initialized purely from an operation.
py::object parentMetaclass =
py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
py::dict attributes;
// TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
// now.
// auto opViewType = py::type::of<PyOpView>();
auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
attributes["__init__"] = opViewType.attr("__init__");
py::str origName = userClass.attr("__name__");
py::str newName = py::str("_") + origName;
return parentMetaclass(newName, py::make_tuple(userClass), attributes);
}
//------------------------------------------------------------------------------
// PyAttribute.
//------------------------------------------------------------------------------
@ -966,6 +1048,41 @@ private:
MlirBlock block;
};
/// A list of operation results. Internally, these are stored as consecutive
/// elements, random access is cheap. The result list is associated with the
/// operation whose results these are, and extends the lifetime of this
/// operation.
class PyOpOperandList {
public:
PyOpOperandList(PyOperationRef operation) : operation(operation) {}
/// Returns the length of the result list.
intptr_t dunderLen() {
operation->checkValid();
return mlirOperationGetNumOperands(operation->get());
}
/// Returns `index`-th element in the result list.
PyOpResult dunderGetItem(intptr_t index) {
if (index < 0 || index >= dunderLen()) {
throw SetPyError(PyExc_IndexError,
"attempt to access out of bounds region");
}
PyValue value(operation, mlirOperationGetOperand(operation->get(), index));
return PyOpResult(value);
}
/// Defines a Python class in the bindings.
static void bind(py::module &m) {
py::class_<PyOpOperandList>(m, "OpOperandList")
.def("__len__", &PyOpOperandList::dunderLen)
.def("__getitem__", &PyOpOperandList::dunderGetItem);
}
private:
PyOperationRef operation;
};
/// A list of operation results. Internally, these are stored as consecutive
/// elements, random access is cheap. The result list is associated with the
/// operation whose results these are, and extends the lifetime of this
@ -1914,7 +2031,9 @@ public:
//------------------------------------------------------------------------------
void mlir::python::populateIRSubmodule(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of MlirContext
//----------------------------------------------------------------------------
py::class_<PyMlirContext>(m, "Context")
.def(py::init<>(&PyMlirContext::createNewContextForInit))
.def_static("_get_live_count", &PyMlirContext::getLiveCount)
@ -1928,6 +2047,25 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
.def_property_readonly(
"dialects",
[](PyMlirContext &self) { return PyDialects(self.getRef()); },
"Gets a container for accessing dialects by name")
.def_property_readonly(
"d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
"Alias for 'dialect'")
.def(
"get_dialect_descriptor",
[=](PyMlirContext &self, std::string &name) {
MlirDialect dialect = mlirContextGetOrLoadDialect(
self.get(), {name.data(), name.size()});
if (mlirDialectIsNull(dialect)) {
throw SetPyError(PyExc_ValueError,
llvm::Twine("Dialect '") + name + "' not found");
}
return PyDialectDescriptor(self.getRef(), dialect);
},
"Gets or loads a dialect by name, returning its descriptor object")
.def_property(
"allow_unregistered_dialects",
[](PyMlirContext &self) -> bool {
@ -1937,8 +2075,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
mlirContextSetAllowUnregisteredDialects(self.get(), value);
})
.def("create_operation", &PyMlirContext::createOperation, py::arg("name"),
py::arg("location"), py::arg("results") = py::none(),
py::arg("attributes") = py::none(),
py::arg("location"), py::arg("operands") = py::none(),
py::arg("results") = py::none(), py::arg("attributes") = py::none(),
py::arg("successors") = py::none(), py::arg("regions") = 0,
kContextCreateOperationDocstring)
.def(
@ -2009,6 +2147,62 @@ void mlir::python::populateIRSubmodule(py::module &m) {
kContextGetFileLocationDocstring, py::arg("filename"),
py::arg("line"), py::arg("col"));
//----------------------------------------------------------------------------
// Mapping of PyDialectDescriptor
//----------------------------------------------------------------------------
py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
.def_property_readonly("namespace",
[](PyDialectDescriptor &self) {
MlirStringRef ns =
mlirDialectGetNamespace(self.get());
return py::str(ns.data, ns.length);
})
.def("__repr__", [](PyDialectDescriptor &self) {
MlirStringRef ns = mlirDialectGetNamespace(self.get());
std::string repr("<DialectDescriptor ");
repr.append(ns.data, ns.length);
repr.append(">");
return repr;
});
//----------------------------------------------------------------------------
// Mapping of PyDialects
//----------------------------------------------------------------------------
py::class_<PyDialects>(m, "Dialects")
.def("__getitem__",
[=](PyDialects &self, std::string keyName) {
MlirDialect dialect =
self.getDialectForKey(keyName, /*attrError=*/false);
py::object descriptor =
py::cast(PyDialectDescriptor{self.getContext(), dialect});
return createCustomDialectWrapper(keyName, std::move(descriptor));
})
.def("__getattr__", [=](PyDialects &self, std::string attrName) {
MlirDialect dialect =
self.getDialectForKey(attrName, /*attrError=*/true);
py::object descriptor =
py::cast(PyDialectDescriptor{self.getContext(), dialect});
return createCustomDialectWrapper(attrName, std::move(descriptor));
});
//----------------------------------------------------------------------------
// Mapping of PyDialect
//----------------------------------------------------------------------------
py::class_<PyDialect>(m, "Dialect")
.def(py::init<py::object>(), "descriptor")
.def_property_readonly(
"descriptor", [](PyDialect &self) { return self.getDescriptor(); })
.def("__repr__", [](py::object self) {
auto clazz = self.attr("__class__");
return py::str("<Dialect ") +
self.attr("descriptor").attr("namespace") + py::str(" (class ") +
clazz.attr("__module__") + py::str(".") +
clazz.attr("__name__") + py::str(")>");
});
//----------------------------------------------------------------------------
// Mapping of Location
//----------------------------------------------------------------------------
py::class_<PyLocation>(m, "Location")
.def_property_readonly(
"context",
@ -2021,7 +2215,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return printAccum.join();
});
//----------------------------------------------------------------------------
// Mapping of Module
//----------------------------------------------------------------------------
py::class_<PyModule>(m, "Module")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
@ -2055,12 +2251,17 @@ void mlir::python::populateIRSubmodule(py::module &m) {
},
kOperationStrDunderDocstring);
//----------------------------------------------------------------------------
// Mapping of Operation.
//----------------------------------------------------------------------------
py::class_<PyOperation>(m, "Operation")
.def_property_readonly(
"context",
[](PyOperation &self) { return self.getContext().getObject(); },
"Context that owns the Operation")
.def_property_readonly(
"operands",
[](PyOperation &self) { return PyOpOperandList(self.getRef()); })
.def_property_readonly(
"regions",
[](PyOperation &self) { return PyRegionList(self.getRef()); })
@ -2098,7 +2299,15 @@ void mlir::python::populateIRSubmodule(py::module &m) {
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false, kOperationGetAsmDocstring);
py::class_<PyOpView>(m, "OpView")
.def(py::init<py::object>())
.def_property_readonly("operation", &PyOpView::getOperationObject)
.def("__str__",
[](PyOpView &self) { return py::str(self.getOperationObject()); });
//----------------------------------------------------------------------------
// Mapping of PyRegion.
//----------------------------------------------------------------------------
py::class_<PyRegion>(m, "Region")
.def_property_readonly(
"blocks",
@ -2123,7 +2332,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
}
});
//----------------------------------------------------------------------------
// Mapping of PyBlock.
//----------------------------------------------------------------------------
py::class_<PyBlock>(m, "Block")
.def_property_readonly(
"arguments",
@ -2167,7 +2378,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
},
"Returns the assembly form of the block.");
//----------------------------------------------------------------------------
// Mapping of PyAttribute.
//----------------------------------------------------------------------------
py::class_<PyAttribute>(m, "Attribute")
.def_property_readonly(
"context",
@ -2219,6 +2432,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return printAccum.join();
});
//----------------------------------------------------------------------------
// Mapping of PyNamedAttribute
//----------------------------------------------------------------------------
py::class_<PyNamedAttribute>(m, "NamedAttribute")
.def("__repr__",
[](PyNamedAttribute &self) {
@ -2257,7 +2473,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyStringAttribute::bind(m);
PyDenseElementsAttribute::bind(m);
//----------------------------------------------------------------------------
// Mapping of PyType.
//----------------------------------------------------------------------------
py::class_<PyType>(m, "Type")
.def_property_readonly(
"context", [](PyType &self) { return self.getContext().getObject(); },
@ -2313,7 +2531,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyTupleType::bind(m);
PyFunctionType::bind(m);
//----------------------------------------------------------------------------
// Mapping of Value.
//----------------------------------------------------------------------------
py::class_<PyValue>(m, "Value")
.def_property_readonly(
"context",
@ -2346,6 +2566,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyBlockList::bind(m);
PyOperationIterator::bind(m);
PyOperationList::bind(m);
PyOpOperandList::bind(m);
PyOpResultList::bind(m);
PyRegionIterator::bind(m);
PyRegionList::bind(m);

View File

@ -132,6 +132,7 @@ public:
/// Creates an operation. See corresponding python docstring.
pybind11::object
createOperation(std::string name, PyLocation location,
llvm::Optional<std::vector<PyValue *>> operands,
llvm::Optional<std::vector<PyType *>> results,
llvm::Optional<pybind11::dict> attributes,
llvm::Optional<std::vector<PyBlock *>> successors,
@ -187,6 +188,45 @@ private:
PyMlirContextRef contextRef;
};
/// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
/// order to differentiate it from the `Dialect` base class which is extended by
/// plugins which extend dialect functionality through extension python code.
/// This should be seen as the "low-level" object and `Dialect` as the
/// high-level, user facing object.
class PyDialectDescriptor : public BaseContextObject {
public:
PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
: BaseContextObject(std::move(contextRef)), dialect(dialect) {}
MlirDialect get() { return dialect; }
private:
MlirDialect dialect;
};
/// User-level object for accessing dialects with dotted syntax such as:
/// ctx.dialect.std
class PyDialects : public BaseContextObject {
public:
PyDialects(PyMlirContextRef contextRef)
: BaseContextObject(std::move(contextRef)) {}
MlirDialect getDialectForKey(const std::string &key, bool attrError);
};
/// User-level dialect object. For dialects that have a registered extension,
/// this will be the base class of the extension dialect type. For un-extended,
/// objects of this type will be returned directly.
class PyDialect {
public:
PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}
pybind11::object getDescriptor() { return descriptor; }
private:
pybind11::object descriptor;
};
/// Wrapper around an MlirLocation.
class PyLocation : public BaseContextObject {
public:
@ -305,6 +345,24 @@ private:
bool valid = true;
};
/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
/// providing more instance-specific accessors and serve as the base class for
/// custom ODS-style operation classes. Since this class is subclass on the
/// python side, it must present an __init__ method that operates in pure
/// python types.
class PyOpView {
public:
PyOpView(pybind11::object operation);
static pybind11::object createRawSubclass(pybind11::object userClass);
pybind11::object getOperationObject() { return operationObject; }
private:
pybind11::object operationObject; // Holds the reference.
PyOperation *operation; // For efficient, cast-free access from C++
};
/// Wrapper around an MlirRegion.
/// Regions are managed completely by their containing operation. Unlike the
/// C++ API, the python API does not support detached regions.

View File

@ -8,17 +8,155 @@
#include <tuple>
#include <pybind11/pybind11.h>
#include "PybindUtils.h"
#include "Globals.h"
#include "IRModules.h"
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
// -----------------------------------------------------------------------------
// PyGlobals
// -----------------------------------------------------------------------------
PyGlobals *PyGlobals::instance = nullptr;
PyGlobals::PyGlobals() {
assert(!instance && "PyGlobals already constructed");
instance = this;
}
PyGlobals::~PyGlobals() { instance = nullptr; }
void PyGlobals::loadDialectModule(const std::string &dialectNamespace) {
if (loadedDialectModules.contains(dialectNamespace))
return;
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
py::object loaded;
for (std::string moduleName : localSearchPrefixes) {
moduleName.push_back('.');
moduleName.append(dialectNamespace);
try {
loaded = py::module::import(moduleName.c_str());
} catch (py::error_already_set &e) {
if (e.matches(PyExc_ModuleNotFoundError)) {
continue;
} else {
throw;
}
}
break;
}
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
// may have occurred, which may do anything.
loadedDialectModules.insert(dialectNamespace);
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::object &found = dialectClassMap[dialectNamespace];
if (found) {
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
dialectNamespace +
"' is already registered.");
}
found = std::move(pyClass);
}
void PyGlobals::registerOperationImpl(const std::string &operationName,
py::object pyClass, py::object rawClass) {
py::object &found = operationClassMap[operationName];
if (found) {
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
operationName +
"' is already registered.");
}
found = std::move(pyClass);
rawOperationClassMap[operationName] = std::move(rawClass);
}
llvm::Optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
loadDialectModule(dialectNamespace);
// Fast match against the class map first (common case).
const auto foundIt = dialectClassMap.find(dialectNamespace);
if (foundIt != dialectClassMap.end()) {
if (foundIt->second.is_none())
return llvm::None;
assert(foundIt->second && "py::object is defined");
return foundIt->second;
}
// Not found and loading did not yield a registration. Negative cache.
dialectClassMap[dialectNamespace] = py::none();
return llvm::None;
}
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
PYBIND11_MODULE(_mlir, m) {
m.doc() = "MLIR Python Native Extension";
py::class_<PyGlobals>(m, "_Globals")
.def_property("dialect_search_modules",
&PyGlobals::getDialectSearchPrefixes,
&PyGlobals::setDialectSearchPrefixes)
.def("append_dialect_search_prefix",
[](PyGlobals &self, std::string moduleName) {
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
})
.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
"Testing hook for directly registering an operation");
// Aside from making the globals accessible to python, having python manage
// it is necessary to make sure it is destroyed (and releases its python
// resources) properly.
m.attr("globals") =
py::cast(new PyGlobals, py::return_value_policy::take_ownership);
// Registration decorators.
m.def(
"register_dialect",
[](py::object pyClass) {
std::string dialectNamespace =
pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
return pyClass;
},
"Class decorator for registering a custom Dialect wrapper");
m.def(
"register_operation",
[](py::object dialectClass) -> py::cpp_function {
return py::cpp_function(
[dialectClass](py::object opClass) -> py::object {
std::string operationName =
opClass.attr("OPERATION_NAME").cast<std::string>();
auto rawSubclass = PyOpView::createRawSubclass(opClass);
PyGlobals::get().registerOperationImpl(operationName, opClass,
rawSubclass);
// Dict-stuff the new opClass by name onto the dialect class.
py::object opClassName = opClass.attr("__name__");
dialectClass.attr(opClassName) = opClass;
// Now create a special "Raw" subclass that passes through
// construction to the OpView parent (bypasses the intermediate
// child's __init__).
opClass.attr("_Raw") = rawSubclass;
return opClass;
});
},
"Class decorator for registering a custom Operation wrapper");
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
populateIRSubmodule(irModule);

View File

@ -8,4 +8,37 @@
# and arbitrate any one-time initialization needed in various shared-library
# scenarios.
from _mlir import *
__all__ = [
"ir",
]
# Expose the corresponding C-Extension module with a well-known name at this
# top-level module. This allows relative imports like the following to
# function:
# from .. import _cext
# This reduces coupling, allowing embedding of the python sources into another
# project that can just vary based on this top-level loader module.
import _mlir as _cext
def _reexport_cext(cext_module_name, target_module_name):
"""Re-exports a named sub-module of the C-Extension into another module.
Typically:
from . import _reexport_cext
_reexport_cext("ir", __name__)
del _reexport_cext
"""
import sys
target_module = sys.modules[target_module_name]
source_module = getattr(_cext, cext_module_name)
for attr_name in dir(source_module):
if not attr_name.startswith("__"):
setattr(target_module, attr_name, getattr(source_module, attr_name))
# Import sub-modules. Since these may import from here, this must come after
# any exported definitions.
from . import ir
# Add our 'dialects' parent module to the search path for implementations.
_cext.globals.append_dialect_search_prefix("mlir.dialects")

View File

@ -0,0 +1,6 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Re-export the parent _cext so that every level of the API can get it locally.
from .. import _cext

View File

@ -0,0 +1,33 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# TODO: This file should be auto-generated.
from . import _cext
@_cext.register_dialect
class _Dialect(_cext.ir.Dialect):
# Special case: 'std' namespace aliases to the empty namespace.
DIALECT_NAMESPACE = "std"
pass
@_cext.register_operation(_Dialect)
class AddFOp(_cext.ir.OpView):
OPERATION_NAME = "std.addf"
def __init__(self, loc, lhs, rhs):
super().__init__(loc.context.create_operation(
"std.addf", loc, operands=[lhs, rhs], results=[lhs.type]))
@property
def lhs(self):
return self.operation.operands[0]
@property
def rhs(self):
return self.operation.operands[1]
@property
def result(self):
return self.operation.results[0]

View File

@ -0,0 +1,8 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Simply a wrapper around the extension module of the same name.
from . import _reexport_cext
_reexport_cext("ir", __name__)
del _reexport_cext

View File

@ -0,0 +1,107 @@
# RUN: %PYTHON %s | FileCheck %s
import gc
import mlir
def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert mlir.ir.Context._get_live_count() == 0
# CHECK-LABEL: TEST: testDialectDescriptor
def testDialectDescriptor():
ctx = mlir.ir.Context()
d = ctx.get_dialect_descriptor("std")
# CHECK: <DialectDescriptor std>
print(d)
# CHECK: std
print(d.namespace)
try:
_ = ctx.get_dialect_descriptor("not_existing")
except ValueError:
pass
else:
assert False, "Expected exception"
run(testDialectDescriptor)
# CHECK-LABEL: TEST: testUserDialectClass
def testUserDialectClass():
ctx = mlir.ir.Context()
# Access using attribute.
d = ctx.dialects.std
# Note that the standard dialect namespace prints as ''. Others will print
# as "<Dialect %namespace (..."
# CHECK: <Dialect (class mlir.dialects.std._Dialect)>
print(d)
try:
_ = ctx.dialects.not_existing
except AttributeError:
pass
else:
assert False, "Expected exception"
# Access using index.
d = ctx.dialects["std"]
# CHECK: <Dialect (class mlir.dialects.std._Dialect)>
print(d)
try:
_ = ctx.dialects["not_existing"]
except IndexError:
pass
else:
assert False, "Expected exception"
# Using the 'd' alias.
d = ctx.d["std"]
# CHECK: <Dialect (class mlir.dialects.std._Dialect)>
print(d)
run(testUserDialectClass)
# CHECK-LABEL: TEST: testCustomOpView
# This test uses the standard dialect AddFOp as an example of a user op.
# TODO: Op creation and access is still quite verbose: simplify this test as
# additional capabilities come online.
def testCustomOpView():
ctx = mlir.ir.Context()
ctx.allow_unregistered_dialects = True
f32 = mlir.ir.F32Type.get(ctx)
loc = ctx.get_unknown_location()
m = ctx.create_module(loc)
m_block = m.operation.regions[0].blocks[0]
# TODO: Remove integer insertion in favor of InsertionPoint and/or op-based.
ip = [0]
def createInput():
op = ctx.create_operation("pytest_dummy.intinput", loc, results=[f32])
m_block.operations.insert(ip[0], op)
ip[0] += 1
# TODO: Auto result cast from operation
return op.results[0]
# Create via dialects context collection.
input1 = createInput()
input2 = createInput()
op1 = ctx.dialects.std.AddFOp(loc, input1, input2)
# TODO: Auto operation cast from OpView
# TODO: Context manager insertion point
m_block.operations.insert(ip[0], op1.operation)
ip[0] += 1
# Create via an import
from mlir.dialects.std import AddFOp
op2 = AddFOp(loc, input1, op1.result)
m_block.operations.insert(ip[0], op2.operation)
ip[0] += 1
# CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
# CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
# CHECK: %[[R0:.*]] = addf %[[INPUT0]], %[[INPUT1]] : f32
# CHECK: %[[R1:.*]] = addf %[[INPUT0]], %[[R0]] : f32
m.operation.print()
run(testCustomOpView)