forked from OSchip/llvm-project
125 lines
4.9 KiB
C++
125 lines
4.9 KiB
C++
//===- Pass.cpp - Pass Management -----------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "Pass.h"
|
|
|
|
#include "IRModule.h"
|
|
#include "mlir-c/Bindings/Python/Interop.h"
|
|
#include "mlir-c/Pass.h"
|
|
|
|
namespace py = pybind11;
|
|
using namespace mlir;
|
|
using namespace mlir::python;
|
|
|
|
namespace {
|
|
|
|
/// Owning Wrapper around a PassManager.
|
|
class PyPassManager {
|
|
public:
|
|
PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
|
|
PyPassManager(PyPassManager &&other) : passManager(other.passManager) {
|
|
other.passManager.ptr = nullptr;
|
|
}
|
|
~PyPassManager() {
|
|
if (!mlirPassManagerIsNull(passManager))
|
|
mlirPassManagerDestroy(passManager);
|
|
}
|
|
MlirPassManager get() { return passManager; }
|
|
|
|
void release() { passManager.ptr = nullptr; }
|
|
pybind11::object getCapsule() {
|
|
return py::reinterpret_steal<py::object>(
|
|
mlirPythonPassManagerToCapsule(get()));
|
|
}
|
|
|
|
static pybind11::object createFromCapsule(pybind11::object capsule) {
|
|
MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
|
|
if (mlirPassManagerIsNull(rawPm))
|
|
throw py::error_already_set();
|
|
return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
|
|
}
|
|
|
|
private:
|
|
MlirPassManager passManager;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Create the `mlir.passmanager` here.
|
|
void mlir::python::populatePassManagerSubmodule(py::module &m) {
|
|
//----------------------------------------------------------------------------
|
|
// Mapping of the top-level PassManager
|
|
//----------------------------------------------------------------------------
|
|
py::class_<PyPassManager>(m, "PassManager", py::module_local())
|
|
.def(py::init<>([](DefaultingPyMlirContext context) {
|
|
MlirPassManager passManager =
|
|
mlirPassManagerCreate(context->get());
|
|
return new PyPassManager(passManager);
|
|
}),
|
|
py::arg("context") = py::none(),
|
|
"Create a new PassManager for the current (or provided) Context.")
|
|
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
|
|
&PyPassManager::getCapsule)
|
|
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
|
|
.def("_testing_release", &PyPassManager::release,
|
|
"Releases (leaks) the backing pass manager (testing)")
|
|
.def(
|
|
"enable_ir_printing",
|
|
[](PyPassManager &passManager) {
|
|
mlirPassManagerEnableIRPrinting(passManager.get());
|
|
},
|
|
"Enable mlir-print-ir-after-all.")
|
|
.def(
|
|
"enable_verifier",
|
|
[](PyPassManager &passManager, bool enable) {
|
|
mlirPassManagerEnableVerifier(passManager.get(), enable);
|
|
},
|
|
py::arg("enable"), "Enable / disable verify-each.")
|
|
.def_static(
|
|
"parse",
|
|
[](const std::string pipeline, DefaultingPyMlirContext context) {
|
|
MlirPassManager passManager = mlirPassManagerCreate(context->get());
|
|
MlirLogicalResult status = mlirParsePassPipeline(
|
|
mlirPassManagerGetAsOpPassManager(passManager),
|
|
mlirStringRefCreate(pipeline.data(), pipeline.size()));
|
|
if (mlirLogicalResultIsFailure(status))
|
|
throw SetPyError(PyExc_ValueError,
|
|
llvm::Twine("invalid pass pipeline '") +
|
|
pipeline + "'.");
|
|
return new PyPassManager(passManager);
|
|
},
|
|
py::arg("pipeline"), py::arg("context") = py::none(),
|
|
"Parse a textual pass-pipeline and return a top-level PassManager "
|
|
"that can be applied on a Module. Throw a ValueError if the pipeline "
|
|
"can't be parsed")
|
|
.def(
|
|
"run",
|
|
[](PyPassManager &passManager, PyModule &module) {
|
|
MlirLogicalResult status =
|
|
mlirPassManagerRun(passManager.get(), module.get());
|
|
if (mlirLogicalResultIsFailure(status))
|
|
throw SetPyError(PyExc_RuntimeError,
|
|
"Failure while executing pass pipeline.");
|
|
},
|
|
py::arg("module"),
|
|
"Run the pass manager on the provided module, throw a RuntimeError "
|
|
"on failure.")
|
|
.def(
|
|
"__str__",
|
|
[](PyPassManager &self) {
|
|
MlirPassManager passManager = self.get();
|
|
PyPrintAccumulator printAccum;
|
|
mlirPrintPassPipeline(
|
|
mlirPassManagerGetAsOpPassManager(passManager),
|
|
printAccum.getCallback(), printAccum.getUserData());
|
|
return printAccum.join();
|
|
},
|
|
"Print the textual representation for this PassManager, suitable to "
|
|
"be passed to `parse` for round-tripping.");
|
|
}
|