[mlir] Add callback to provide a pass pipeline to MlirOptMain

The callback can be used to provide a default pass pipeline.

Reviewed By: mehdi_amini, rriddle

Differential Revision: https://reviews.llvm.org/D113144
This commit is contained in:
Deepak Panickal 2021-11-05 17:39:57 +00:00 committed by Mehdi Amini
parent 5c3d7184b4
commit 97c899f3c5
2 changed files with 47 additions and 16 deletions

View File

@ -27,6 +27,12 @@ class MemoryBuffer;
namespace mlir {
class DialectRegistry;
class PassPipelineCLParser;
class PassManager;
/// This defines the function type used to setup the pass manager. This can be
/// used to pass in a callback to setup a default pass pipeline to be applied on
/// the loaded IR.
using PassPipelineFn = llvm::function_ref<LogicalResult(PassManager &pm)>;
/// Perform the core processing behind `mlir-opt`:
/// - outputStream is the stream where the resulting IR is printed.
@ -52,6 +58,17 @@ LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
bool allowUnregisteredDialects,
bool preloadDialectsInContext = false);
/// Support a callback to setup the pass manager.
/// - passManagerSetupFn is the callback invoked to setup the pass manager to
/// apply on the loaded IR.
LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
std::unique_ptr<llvm::MemoryBuffer> buffer,
PassPipelineFn passManagerSetupFn,
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
bool preloadDialectsInContext = false);
/// Implementation for tools like `mlir-opt`.
/// - toolName is used for the header displayed by `--help`.
/// - registry should contain all the dialects that can be parsed in the source.

View File

@ -48,7 +48,7 @@ using llvm::SMLoc;
static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
bool verifyPasses, SourceMgr &sourceMgr,
MLIRContext *context,
const PassPipelineCLParser &passPipeline) {
PassPipelineFn passManagerSetupFn) {
DefaultTimingManager tm;
applyDefaultTimingManagerCLOptions(tm);
TimingScope timing = tm.getRootScope();
@ -72,13 +72,8 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
applyPassManagerCLOptions(pm);
pm.enableTiming(timing);
auto errorHandler = [&](const Twine &msg) {
emitError(UnknownLoc::get(context)) << msg;
return failure();
};
// Build the provided pipeline.
if (failed(passPipeline.addToPipeline(pm, errorHandler)))
// Callback to build the pipeline.
if (failed(passManagerSetupFn(pm)))
return failure();
// Run the pipeline.
@ -98,8 +93,8 @@ static LogicalResult
processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects, bool preloadDialectsInContext,
const PassPipelineCLParser &passPipeline,
DialectRegistry &registry, llvm::ThreadPool &threadPool) {
PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
llvm::ThreadPool &threadPool) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@ -122,7 +117,7 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
if (!verifyDiagnostics) {
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
&context, passPipeline);
&context, passManagerSetupFn);
}
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
@ -131,7 +126,7 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
// these actions succeed or fail, we only care what diagnostics they produce
// and whether they match our expectations.
(void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
passPipeline);
passManagerSetupFn);
// Verify the diagnostic handler to make sure that each of the diagnostics
// matched.
@ -140,7 +135,7 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
std::unique_ptr<MemoryBuffer> buffer,
const PassPipelineCLParser &passPipeline,
PassPipelineFn passManagerSetupFn,
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
@ -156,17 +151,36 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
[&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
preloadDialectsInContext, passPipeline, registry,
threadPool);
preloadDialectsInContext, passManagerSetupFn,
registry, threadPool);
},
outputStream);
return processBuffer(outputStream, std::move(buffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
preloadDialectsInContext, passPipeline, registry,
preloadDialectsInContext, passManagerSetupFn, registry,
threadPool);
}
LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
std::unique_ptr<MemoryBuffer> buffer,
const PassPipelineCLParser &passPipeline,
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
bool preloadDialectsInContext) {
auto passManagerSetupFn = [&](PassManager &pm) {
auto errorHandler = [&](const Twine &msg) {
emitError(UnknownLoc::get(pm.getContext())) << msg;
return failure();
};
return passPipeline.addToPipeline(pm, errorHandler);
};
return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
registry, splitInputFile, verifyDiagnostics, verifyPasses,
allowUnregisteredDialects, preloadDialectsInContext);
}
LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
DialectRegistry &registry,
bool preloadDialectsInContext) {