forked from OSchip/llvm-project
[mlir] JitRunner: add a config option to register symbols with ExecutionEngine at runtime
Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D90264
This commit is contained in:
parent
50dfa19cc7
commit
f6c9f6eccd
|
@ -18,29 +18,42 @@
|
|||
#ifndef MLIR_SUPPORT_JITRUNNER_H_
|
||||
#define MLIR_SUPPORT_JITRUNNER_H_
|
||||
|
||||
#include "mlir/IR/Module.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/ExecutionEngine/Orc/Core.h"
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
class LLVMContext;
|
||||
|
||||
namespace orc {
|
||||
class MangleAndInterner;
|
||||
} // namespace orc
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
|
||||
using TranslationCallback = llvm::function_ref<std::unique_ptr<llvm::Module>(
|
||||
ModuleOp, llvm::LLVMContext &)>;
|
||||
|
||||
class ModuleOp;
|
||||
struct LogicalResult;
|
||||
|
||||
struct JitRunnerConfig {
|
||||
/// MLIR transformer applied after parsing the input into MLIR IR and before
|
||||
/// passing the MLIR module to the ExecutionEngine.
|
||||
llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer = nullptr;
|
||||
|
||||
/// A custom function that is passed to ExecutionEngine. It processes MLIR
|
||||
/// module and creates LLVM IR module.
|
||||
llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
|
||||
llvm::LLVMContext &)>
|
||||
llvmModuleBuilder = nullptr;
|
||||
|
||||
/// A callback to register symbols with ExecutionEngine at runtime.
|
||||
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
|
||||
runtimesymbolMap = nullptr;
|
||||
};
|
||||
|
||||
// Entry point for all CPU runners. Expects the common argc/argv arguments for
|
||||
// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
|
||||
/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
|
||||
/// passing the MLIR module to the ExecutionEngine.
|
||||
/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
|
||||
/// It processes MLIR module and creates LLVM IR module.
|
||||
int JitRunnerMain(
|
||||
int argc, char **argv,
|
||||
llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
|
||||
TranslationCallback llvmModuleBuilder = nullptr);
|
||||
// standard C++ main functions.
|
||||
int JitRunnerMain(int argc, char **argv, JitRunnerConfig config = {});
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -92,6 +92,23 @@ struct Options {
|
|||
"object-filename",
|
||||
llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
|
||||
};
|
||||
|
||||
struct CompileAndExecuteConfig {
|
||||
/// LLVM module transformer that is passed to ExecutionEngine.
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> transformer;
|
||||
|
||||
/// A custom function that is passed to ExecutionEngine. It processes MLIR
|
||||
/// module and creates LLVM IR module.
|
||||
llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
|
||||
llvm::LLVMContext &)>
|
||||
llvmModuleBuilder;
|
||||
|
||||
/// A custom function that is passed to ExecutinEngine to register symbols at
|
||||
/// runtime.
|
||||
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
|
||||
runtimeSymbolMap;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
static OwningModuleRef parseMLIRInput(StringRef inputFilename,
|
||||
|
@ -131,11 +148,9 @@ static Optional<unsigned> getCommandLineOptLevel(Options &options) {
|
|||
}
|
||||
|
||||
// JIT-compile the given module and run "entryPoint" with "args" as arguments.
|
||||
static Error
|
||||
compileAndExecute(Options &options, ModuleOp module,
|
||||
TranslationCallback llvmModuleBuilder, StringRef entryPoint,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer,
|
||||
void **args) {
|
||||
static Error compileAndExecute(Options &options, ModuleOp module,
|
||||
StringRef entryPoint,
|
||||
CompileAndExecuteConfig config, void **args) {
|
||||
Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
|
||||
if (auto clOptLevel = getCommandLineOptLevel(options))
|
||||
jitCodeGenOptLevel =
|
||||
|
@ -143,11 +158,15 @@ compileAndExecute(Options &options, ModuleOp module,
|
|||
SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
|
||||
options.clSharedLibs.end());
|
||||
auto expectedEngine = mlir::ExecutionEngine::create(
|
||||
module, llvmModuleBuilder, transformer, jitCodeGenOptLevel, libs);
|
||||
module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
|
||||
libs);
|
||||
if (!expectedEngine)
|
||||
return expectedEngine.takeError();
|
||||
|
||||
auto engine = std::move(*expectedEngine);
|
||||
if (config.runtimeSymbolMap)
|
||||
engine->registerSymbols(config.runtimeSymbolMap);
|
||||
|
||||
auto expectedFPtr = engine->lookup(entryPoint);
|
||||
if (!expectedFPtr)
|
||||
return expectedFPtr.takeError();
|
||||
|
@ -163,16 +182,14 @@ compileAndExecute(Options &options, ModuleOp module,
|
|||
return Error::success();
|
||||
}
|
||||
|
||||
static Error compileAndExecuteVoidFunction(
|
||||
Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
|
||||
StringRef entryPoint,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||
static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
|
||||
StringRef entryPoint,
|
||||
CompileAndExecuteConfig config) {
|
||||
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
|
||||
if (!mainFunction || mainFunction.empty())
|
||||
return make_string_error("entry point not found");
|
||||
void *empty = nullptr;
|
||||
return compileAndExecute(options, module, llvmModuleBuilder, entryPoint,
|
||||
transformer, &empty);
|
||||
return compileAndExecute(options, module, entryPoint, config, &empty);
|
||||
}
|
||||
|
||||
template <typename Type>
|
||||
|
@ -196,10 +213,9 @@ Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
|
|||
return Error::success();
|
||||
}
|
||||
template <typename Type>
|
||||
Error compileAndExecuteSingleReturnFunction(
|
||||
Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
|
||||
StringRef entryPoint,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||
Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
|
||||
StringRef entryPoint,
|
||||
CompileAndExecuteConfig config) {
|
||||
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
|
||||
if (!mainFunction || mainFunction.isExternal())
|
||||
return make_string_error("entry point not found");
|
||||
|
@ -215,8 +231,8 @@ Error compileAndExecuteSingleReturnFunction(
|
|||
void *data;
|
||||
} data;
|
||||
data.data = &res;
|
||||
if (auto error = compileAndExecute(options, module, llvmModuleBuilder,
|
||||
entryPoint, transformer, (void **)&data))
|
||||
if (auto error = compileAndExecute(options, module, entryPoint, config,
|
||||
(void **)&data))
|
||||
return error;
|
||||
|
||||
// Intentional printing of the output so we can test.
|
||||
|
@ -226,15 +242,8 @@ Error compileAndExecuteSingleReturnFunction(
|
|||
}
|
||||
|
||||
/// Entry point for all CPU runners. Expects the common argc/argv arguments for
|
||||
/// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
|
||||
/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
|
||||
/// passing the MLIR module to the ExecutionEngine.
|
||||
/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
|
||||
/// It processes MLIR module and creates LLVM IR module.
|
||||
int mlir::JitRunnerMain(
|
||||
int argc, char **argv,
|
||||
function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
|
||||
TranslationCallback llvmModuleBuilder) {
|
||||
/// standard C++ main functions.
|
||||
int mlir::JitRunnerMain(int argc, char **argv, JitRunnerConfig config) {
|
||||
// Create the options struct containing the command line options for the
|
||||
// runner. This must come before the command line options are parsed.
|
||||
Options options;
|
||||
|
@ -274,8 +283,8 @@ int mlir::JitRunnerMain(
|
|||
return 1;
|
||||
}
|
||||
|
||||
if (mlirTransformer)
|
||||
if (failed(mlirTransformer(m.get())))
|
||||
if (config.mlirTransformer)
|
||||
if (failed(config.mlirTransformer(m.get())))
|
||||
return EXIT_FAILURE;
|
||||
|
||||
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
|
||||
|
@ -292,10 +301,14 @@ int mlir::JitRunnerMain(
|
|||
auto transformer = mlir::makeLLVMPassesTransformer(
|
||||
passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
|
||||
|
||||
CompileAndExecuteConfig compileAndExecuteConfig;
|
||||
compileAndExecuteConfig.transformer = transformer;
|
||||
compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
|
||||
compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
|
||||
|
||||
// Get the function used to compile and execute the module.
|
||||
using CompileAndExecuteFnT =
|
||||
Error (*)(Options &, ModuleOp, TranslationCallback, StringRef,
|
||||
std::function<llvm::Error(llvm::Module *)>);
|
||||
Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
|
||||
auto compileAndExecuteFn =
|
||||
StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
|
||||
.Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
|
||||
|
@ -304,11 +317,11 @@ int mlir::JitRunnerMain(
|
|||
.Case("void", compileAndExecuteVoidFunction)
|
||||
.Default(nullptr);
|
||||
|
||||
Error error =
|
||||
compileAndExecuteFn
|
||||
? compileAndExecuteFn(options, m.get(), llvmModuleBuilder,
|
||||
options.mainFuncName.getValue(), transformer)
|
||||
: make_string_error("unsupported function type");
|
||||
Error error = compileAndExecuteFn
|
||||
? compileAndExecuteFn(options, m.get(),
|
||||
options.mainFuncName.getValue(),
|
||||
compileAndExecuteConfig)
|
||||
: make_string_error("unsupported function type");
|
||||
|
||||
int exitCode = EXIT_SUCCESS;
|
||||
llvm::handleAllErrors(std::move(error),
|
||||
|
|
|
@ -24,5 +24,5 @@ int main(int argc, char **argv) {
|
|||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
mlir::initializeLLVMPasses();
|
||||
|
||||
return mlir::JitRunnerMain(argc, argv, nullptr);
|
||||
return mlir::JitRunnerMain(argc, argv);
|
||||
}
|
||||
|
|
|
@ -136,5 +136,9 @@ int main(int argc, char **argv) {
|
|||
LLVMInitializeNVPTXAsmPrinter();
|
||||
|
||||
mlir::initializeLLVMPasses();
|
||||
return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
|
||||
|
||||
mlir::JitRunnerConfig jitRunnerConfig;
|
||||
jitRunnerConfig.mlirTransformer = &runMLIRPasses;
|
||||
|
||||
return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
|
||||
}
|
||||
|
|
|
@ -86,5 +86,9 @@ int main(int argc, char **argv) {
|
|||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
mlir::initializeLLVMPasses();
|
||||
|
||||
return mlir::JitRunnerMain(argc, argv, &runMLIRPasses, &convertMLIRModule);
|
||||
mlir::JitRunnerConfig jitRunnerConfig;
|
||||
jitRunnerConfig.mlirTransformer = &runMLIRPasses;
|
||||
jitRunnerConfig.llvmModuleBuilder = &convertMLIRModule;
|
||||
|
||||
return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
|
||||
}
|
||||
|
|
|
@ -58,5 +58,8 @@ int main(int argc, char **argv) {
|
|||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
mlir::initializeLLVMPasses();
|
||||
|
||||
return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
|
||||
mlir::JitRunnerConfig jitRunnerConfig;
|
||||
jitRunnerConfig.mlirTransformer = &runMLIRPasses;
|
||||
|
||||
return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue