diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index 3117d0f33dc6..98101b82f542 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -229,12 +229,38 @@ public: //===--------------------------------------------------------------------===// // Pass Timing + /// A configuration struct provided to the pass timing feature. + class PassTimingConfig { + public: + using PrintCallbackFn = function_ref; + + /// Initialize the configuration. + /// * 'displayMode' switch between list or pipeline display (see the + /// `PassDisplayMode` enum documentation). + explicit PassTimingConfig( + PassDisplayMode displayMode = PassDisplayMode::Pipeline) + : displayMode(displayMode) {} + + virtual ~PassTimingConfig(); + + /// A hook that may be overridden by a derived config to control the + /// printing. The callback is supplied by the framework and the config is + /// responsible to call it back with a stream for the output. + virtual void printTiming(PrintCallbackFn printCallback); + + /// Return the `PassDisplayMode` this config was created with. + PassDisplayMode getDisplayMode() { return displayMode; } + + private: + PassDisplayMode displayMode; + }; + /// Add an instrumentation to time the execution of passes and the computation /// of analyses. /// Note: Timing should be enabled after all other instrumentations to avoid /// any potential "ghost" timing from other instrumentations being /// unintentionally included in the timing results. - void enableTiming(PassDisplayMode displayMode = PassDisplayMode::Pipeline); + void enableTiming(std::unique_ptr config = nullptr); /// Prompts the pass manager to print the statistics collected for each of the /// held passes after each call to 'run'. diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp index e0c4df56cf72..953faa28c28a 100644 --- a/mlir/lib/Pass/PassManagerOptions.cpp +++ b/mlir/lib/Pass/PassManagerOptions.cpp @@ -141,7 +141,8 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) { /// Add a pass timing instrumentation if enabled by 'pass-timing' flags. void PassManagerOptions::addTimingInstrumentation(PassManager &pm) { if (passTiming) - pm.enableTiming(passTimingDisplayMode); + pm.enableTiming( + std::make_unique(passTimingDisplayMode)); } void mlir::registerPassManagerCLOptions() { diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp index 663cbdad7c39..c8f0ad8afa50 100644 --- a/mlir/lib/Pass/PassTiming.cpp +++ b/mlir/lib/Pass/PassTiming.cpp @@ -160,7 +160,8 @@ struct Timer { }; struct PassTiming : public PassInstrumentation { - PassTiming(PassDisplayMode displayMode) : displayMode(displayMode) {} + PassTiming(std::unique_ptr config) + : config(std::move(config)) {} ~PassTiming() override { print(); } /// Setup the instrumentation hooks. @@ -231,8 +232,8 @@ struct PassTiming : public PassInstrumentation { /// A stack of the currently active pass timers per thread. DenseMap> activeThreadTimers; - /// The display mode to use when printing the timing results. - PassDisplayMode displayMode; + /// The configuration object to use when printing the timing results. + std::unique_ptr config; /// A mapping of pipeline timers that need to be merged into the parent /// collection. The timers are mapped to the parent info to merge into. @@ -353,28 +354,37 @@ void PassTiming::print() { return; assert(rootTimers.size() == 1 && "expected one remaining root timer"); - auto &rootTimer = rootTimers.begin()->second; - auto os = llvm::CreateInfoOutputFile(); - // Print the timer header. - TimeRecord totalTime = rootTimer->getTotalTime(); - printTimerHeader(*os, totalTime); + auto printCallback = [&](raw_ostream &os) { + auto &rootTimer = rootTimers.begin()->second; + // Print the timer header. + TimeRecord totalTime = rootTimer->getTotalTime(); + printTimerHeader(os, totalTime); + // Defer to a specialized printer for each display mode. + switch (config->getDisplayMode()) { + case PassDisplayMode::List: + printResultsAsList(os, rootTimer.get(), totalTime); + break; + case PassDisplayMode::Pipeline: + printResultsAsPipeline(os, rootTimer.get(), totalTime); + break; + } + printTimeEntry(os, 0, "Total", totalTime, totalTime); + os.flush(); - // Defer to a specialized printer for each display mode. - switch (displayMode) { - case PassDisplayMode::List: - printResultsAsList(*os, rootTimer.get(), totalTime); - break; - case PassDisplayMode::Pipeline: - printResultsAsPipeline(*os, rootTimer.get(), totalTime); - break; - } - printTimeEntry(*os, 0, "Total", totalTime, totalTime); - os->flush(); + // Reset root timers. + rootTimers.clear(); + activeThreadTimers.clear(); + }; - // Reset root timers. - rootTimers.clear(); - activeThreadTimers.clear(); + config->printTiming(printCallback); +} + +// The default implementation for printTiming uses +// `llvm::CreateInfoOutputFile()` as stream, it can be overridden by clients +// to customize the output. +void PassManager::PassTimingConfig::printTiming(PrintCallbackFn printCallback) { + printCallback(*llvm::CreateInfoOutputFile()); } /// Print the timing result in list mode. @@ -449,16 +459,21 @@ void PassTiming::printResultsAsPipeline(raw_ostream &os, Timer *root, printTimer(0, topLevelTimer.second.get()); } +// Out-of-line as key function. +PassManager::PassTimingConfig::~PassTimingConfig() {} + //===----------------------------------------------------------------------===// // PassManager //===----------------------------------------------------------------------===// /// Add an instrumentation to time the execution of passes and the computation /// of analyses. -void PassManager::enableTiming(PassDisplayMode displayMode) { +void PassManager::enableTiming(std::unique_ptr config) { // Check if pass timing is already enabled. if (passTiming) return; - addInstrumentation(std::make_unique(displayMode)); + if (!config) + config = std::make_unique(); + addInstrumentation(std::make_unique(std::move(config))); passTiming = true; }