forked from OSchip/llvm-project
ExecutionEngine: allow for running MLIR passes during JIT-compilation
The existing implementation of the ExecutionEngine unconditionally runs a list of "default" MLIR passes on the module upon creation. These passes include, among others, dialect conversions from affine to standard and from standard to LLVM IR dialects. In some cases, these conversions might have been performed before ExecutionEngine is created. More advanced use cases may be performing additional transformations that the "default" passes will conflict with. Provide an overload for ExecutionEngine::create that takes a PassManager configured with the passes to run on the module. If it is not provided, do not run any passes. The engine will not be created if the input module, after the pass manager, has any other dialect than the LLVM IR dialect. -- PiperOrigin-RevId: 242127393
This commit is contained in:
parent
7a640e65e9
commit
33285de937
|
@ -37,6 +37,7 @@ class Module;
|
|||
namespace mlir {
|
||||
|
||||
class Module;
|
||||
class PassManager;
|
||||
|
||||
namespace impl {
|
||||
class OrcJIT;
|
||||
|
@ -56,6 +57,14 @@ class ExecutionEngine {
|
|||
public:
|
||||
~ExecutionEngine();
|
||||
|
||||
/// Creates an execution engine for the given module. If `pm` is provided,
|
||||
/// runs it on the MLIR module. If `transformer` is
|
||||
/// provided, it will be called on the LLVM module during JIT-compilation and
|
||||
/// can be used, e.g., for reporting or optimization.
|
||||
static llvm::Expected<std::unique_ptr<ExecutionEngine>>
|
||||
create(Module *m, PassManager *pm,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer = {});
|
||||
|
||||
/// Creates an execution engine for the given module. If `transformer` is
|
||||
/// provided, it will be called on the LLVM module during JIT-compilation and
|
||||
/// can be used, e.g., for reporting or optimization.
|
||||
|
|
|
@ -34,12 +34,11 @@ namespace mlir {
|
|||
|
||||
class Module;
|
||||
|
||||
/// Convert the given MLIR module into LLVM IR. Create an LLVM IR module in
|
||||
/// "llvmContext" and return a unique pointer to it. In case of error, report it
|
||||
/// Convert the given MLIR module into LLVM IR. The LLVM context is extracted
|
||||
/// from the registered LLVM IR dialect. In case of error, report it
|
||||
/// to the error handler registered with the MLIR context, if any (obtained from
|
||||
/// the MLIR module), and return `nullptr`.
|
||||
std::unique_ptr<llvm::Module>
|
||||
convertModuleToLLVMIR(Module &module, llvm::LLVMContext &llvmContext);
|
||||
std::unique_ptr<llvm::Module> translateModuleToLLVMIR(Module &m);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -273,19 +273,15 @@ void packFunctionArguments(llvm::Module *module) {
|
|||
// Out of line for PIMPL unique_ptr.
|
||||
ExecutionEngine::~ExecutionEngine() = default;
|
||||
|
||||
std::unique_ptr<llvm::Module> translateModuleToLLVMIR(Module &m);
|
||||
|
||||
Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
|
||||
Module *m, std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||
Module *m, PassManager *pm,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||
auto engine = llvm::make_unique<ExecutionEngine>();
|
||||
auto expectedJIT = impl::OrcJIT::createDefault(transformer);
|
||||
if (!expectedJIT)
|
||||
return expectedJIT.takeError();
|
||||
|
||||
// Construct and run the default MLIR pipeline.
|
||||
PassManager manager;
|
||||
getDefaultPasses(manager, {});
|
||||
if (failed(manager.run(m)))
|
||||
if (pm && failed(pm->run(m)))
|
||||
return make_string_error("passes failed");
|
||||
|
||||
auto llvmModule = translateModuleToLLVMIR(*m);
|
||||
|
@ -304,6 +300,14 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
|
|||
return std::move(engine);
|
||||
}
|
||||
|
||||
Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
|
||||
Module *m, std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||
// Construct and run the default MLIR pipeline.
|
||||
PassManager manager;
|
||||
getDefaultPasses(manager, {});
|
||||
return create(m, &manager, transformer);
|
||||
}
|
||||
|
||||
Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
|
||||
auto expectedSymbol = jit->lookup(makePackedFunctionName(name));
|
||||
if (!expectedSymbol)
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "mlir/StandardOps/Ops.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Target/LLVMIR.h"
|
||||
#include "mlir/Translation.h"
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
|
@ -441,7 +442,6 @@ bool ModuleTranslation::convertFunctions() {
|
|||
}
|
||||
|
||||
std::unique_ptr<llvm::Module> ModuleTranslation::translateModule(Module &m) {
|
||||
|
||||
Dialect *dialect = m.getContext()->getRegisteredDialect("llvm");
|
||||
assert(dialect && "LLVM dialect must be registered");
|
||||
auto *llvmDialect = static_cast<LLVM::LLVMDialect *>(dialect);
|
||||
|
@ -468,7 +468,7 @@ std::unique_ptr<llvm::Module> ModuleTranslation::translateModule(Module &m) {
|
|||
return std::move(translator.llvmModule);
|
||||
}
|
||||
|
||||
std::unique_ptr<llvm::Module> translateModuleToLLVMIR(Module &m) {
|
||||
std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(Module &m) {
|
||||
return ModuleTranslation::translateModule(m);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue