ExecutionEngine: allow for running MLIR passes during JIT-compilation

The existing implementation of the ExecutionEngine unconditionally runs a list
    of "default" MLIR passes on the module upon creation.  These passes include,
    among others, dialect conversions from affine to standard and from standard to
    LLVM IR dialects.  In some cases, these conversions might have been performed
    before ExecutionEngine is created.  More advanced use cases may be performing
    additional transformations that the "default" passes will conflict with.
    Provide an overload for ExecutionEngine::create that takes a PassManager
    configured with the passes to run on the module.  If it is not provided, do not
    run any passes.  The engine will not be created if the input module, after the
    pass manager, has any other dialect than the LLVM IR dialect.

--

PiperOrigin-RevId: 242127393
This commit is contained in:
Alex Zinenko 2019-04-05 08:15:41 -07:00 committed by Mehdi Amini
parent 7a640e65e9
commit 33285de937
4 changed files with 25 additions and 13 deletions

View File

@ -37,6 +37,7 @@ class Module;
namespace mlir {
class Module;
class PassManager;
namespace impl {
class OrcJIT;
@ -56,6 +57,14 @@ class ExecutionEngine {
public:
~ExecutionEngine();
/// Creates an execution engine for the given module. If `pm` is provided,
/// runs it on the MLIR module. If `transformer` is
/// provided, it will be called on the LLVM module during JIT-compilation and
/// can be used, e.g., for reporting or optimization.
static llvm::Expected<std::unique_ptr<ExecutionEngine>>
create(Module *m, PassManager *pm,
std::function<llvm::Error(llvm::Module *)> transformer = {});
/// Creates an execution engine for the given module. If `transformer` is
/// provided, it will be called on the LLVM module during JIT-compilation and
/// can be used, e.g., for reporting or optimization.

View File

@ -34,12 +34,11 @@ namespace mlir {
class Module;
/// Convert the given MLIR module into LLVM IR. Create an LLVM IR module in
/// "llvmContext" and return a unique pointer to it. In case of error, report it
/// Convert the given MLIR module into LLVM IR. The LLVM context is extracted
/// from the registered LLVM IR dialect. In case of error, report it
/// to the error handler registered with the MLIR context, if any (obtained from
/// the MLIR module), and return `nullptr`.
std::unique_ptr<llvm::Module>
convertModuleToLLVMIR(Module &module, llvm::LLVMContext &llvmContext);
std::unique_ptr<llvm::Module> translateModuleToLLVMIR(Module &m);
} // namespace mlir

View File

@ -273,19 +273,15 @@ void packFunctionArguments(llvm::Module *module) {
// Out of line for PIMPL unique_ptr.
ExecutionEngine::~ExecutionEngine() = default;
std::unique_ptr<llvm::Module> translateModuleToLLVMIR(Module &m);
Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
Module *m, std::function<llvm::Error(llvm::Module *)> transformer) {
Module *m, PassManager *pm,
std::function<llvm::Error(llvm::Module *)> transformer) {
auto engine = llvm::make_unique<ExecutionEngine>();
auto expectedJIT = impl::OrcJIT::createDefault(transformer);
if (!expectedJIT)
return expectedJIT.takeError();
// Construct and run the default MLIR pipeline.
PassManager manager;
getDefaultPasses(manager, {});
if (failed(manager.run(m)))
if (pm && failed(pm->run(m)))
return make_string_error("passes failed");
auto llvmModule = translateModuleToLLVMIR(*m);
@ -304,6 +300,14 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
return std::move(engine);
}
Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
Module *m, std::function<llvm::Error(llvm::Module *)> transformer) {
// Construct and run the default MLIR pipeline.
PassManager manager;
getDefaultPasses(manager, {});
return create(m, &manager, transformer);
}
Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
auto expectedSymbol = jit->lookup(makePackedFunctionName(name));
if (!expectedSymbol)

View File

@ -25,6 +25,7 @@
#include "mlir/StandardOps/Ops.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Translation.h"
#include "llvm/ADT/SetVector.h"
@ -441,7 +442,6 @@ bool ModuleTranslation::convertFunctions() {
}
std::unique_ptr<llvm::Module> ModuleTranslation::translateModule(Module &m) {
Dialect *dialect = m.getContext()->getRegisteredDialect("llvm");
assert(dialect && "LLVM dialect must be registered");
auto *llvmDialect = static_cast<LLVM::LLVMDialect *>(dialect);
@ -468,7 +468,7 @@ std::unique_ptr<llvm::Module> ModuleTranslation::translateModule(Module &m) {
return std::move(translator.llvmModule);
}
std::unique_ptr<llvm::Module> translateModuleToLLVMIR(Module &m) {
std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(Module &m) {
return ModuleTranslation::translateModule(m);
}