Make MLIR Pass Timing output configurable through injection

This makes it possible for the client to control where the pass timings will
be printed.

Differential Revision: https://reviews.llvm.org/D78891
This commit is contained in:
Mehdi Amini 2020-04-27 23:38:17 +00:00
parent cd84bfb814
commit f65a3f7c83
3 changed files with 68 additions and 26 deletions

View File

@ -229,12 +229,38 @@ public:
//===--------------------------------------------------------------------===//
// Pass Timing
/// A configuration struct provided to the pass timing feature.
class PassTimingConfig {
public:
using PrintCallbackFn = function_ref<void(raw_ostream &)>;
/// Initialize the configuration.
/// * 'displayMode' switch between list or pipeline display (see the
/// `PassDisplayMode` enum documentation).
explicit PassTimingConfig(
PassDisplayMode displayMode = PassDisplayMode::Pipeline)
: displayMode(displayMode) {}
virtual ~PassTimingConfig();
/// A hook that may be overridden by a derived config to control the
/// printing. The callback is supplied by the framework and the config is
/// responsible to call it back with a stream for the output.
virtual void printTiming(PrintCallbackFn printCallback);
/// Return the `PassDisplayMode` this config was created with.
PassDisplayMode getDisplayMode() { return displayMode; }
private:
PassDisplayMode displayMode;
};
/// Add an instrumentation to time the execution of passes and the computation
/// of analyses.
/// Note: Timing should be enabled after all other instrumentations to avoid
/// any potential "ghost" timing from other instrumentations being
/// unintentionally included in the timing results.
void enableTiming(PassDisplayMode displayMode = PassDisplayMode::Pipeline);
void enableTiming(std::unique_ptr<PassTimingConfig> config = nullptr);
/// Prompts the pass manager to print the statistics collected for each of the
/// held passes after each call to 'run'.

View File

@ -141,7 +141,8 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) {
/// Add a pass timing instrumentation if enabled by 'pass-timing' flags.
void PassManagerOptions::addTimingInstrumentation(PassManager &pm) {
if (passTiming)
pm.enableTiming(passTimingDisplayMode);
pm.enableTiming(
std::make_unique<PassManager::PassTimingConfig>(passTimingDisplayMode));
}
void mlir::registerPassManagerCLOptions() {

View File

@ -160,7 +160,8 @@ struct Timer {
};
struct PassTiming : public PassInstrumentation {
PassTiming(PassDisplayMode displayMode) : displayMode(displayMode) {}
PassTiming(std::unique_ptr<PassManager::PassTimingConfig> config)
: config(std::move(config)) {}
~PassTiming() override { print(); }
/// Setup the instrumentation hooks.
@ -231,8 +232,8 @@ struct PassTiming : public PassInstrumentation {
/// A stack of the currently active pass timers per thread.
DenseMap<uint64_t, SmallVector<Timer *, 4>> activeThreadTimers;
/// The display mode to use when printing the timing results.
PassDisplayMode displayMode;
/// The configuration object to use when printing the timing results.
std::unique_ptr<PassManager::PassTimingConfig> config;
/// A mapping of pipeline timers that need to be merged into the parent
/// collection. The timers are mapped to the parent info to merge into.
@ -353,28 +354,37 @@ void PassTiming::print() {
return;
assert(rootTimers.size() == 1 && "expected one remaining root timer");
auto &rootTimer = rootTimers.begin()->second;
auto os = llvm::CreateInfoOutputFile();
auto printCallback = [&](raw_ostream &os) {
auto &rootTimer = rootTimers.begin()->second;
// Print the timer header.
TimeRecord totalTime = rootTimer->getTotalTime();
printTimerHeader(*os, totalTime);
printTimerHeader(os, totalTime);
// Defer to a specialized printer for each display mode.
switch (displayMode) {
switch (config->getDisplayMode()) {
case PassDisplayMode::List:
printResultsAsList(*os, rootTimer.get(), totalTime);
printResultsAsList(os, rootTimer.get(), totalTime);
break;
case PassDisplayMode::Pipeline:
printResultsAsPipeline(*os, rootTimer.get(), totalTime);
printResultsAsPipeline(os, rootTimer.get(), totalTime);
break;
}
printTimeEntry(*os, 0, "Total", totalTime, totalTime);
os->flush();
printTimeEntry(os, 0, "Total", totalTime, totalTime);
os.flush();
// Reset root timers.
rootTimers.clear();
activeThreadTimers.clear();
};
config->printTiming(printCallback);
}
// The default implementation for printTiming uses
// `llvm::CreateInfoOutputFile()` as stream, it can be overridden by clients
// to customize the output.
void PassManager::PassTimingConfig::printTiming(PrintCallbackFn printCallback) {
printCallback(*llvm::CreateInfoOutputFile());
}
/// Print the timing result in list mode.
@ -449,16 +459,21 @@ void PassTiming::printResultsAsPipeline(raw_ostream &os, Timer *root,
printTimer(0, topLevelTimer.second.get());
}
// Out-of-line as key function.
PassManager::PassTimingConfig::~PassTimingConfig() {}
//===----------------------------------------------------------------------===//
// PassManager
//===----------------------------------------------------------------------===//
/// Add an instrumentation to time the execution of passes and the computation
/// of analyses.
void PassManager::enableTiming(PassDisplayMode displayMode) {
void PassManager::enableTiming(std::unique_ptr<PassTimingConfig> config) {
// Check if pass timing is already enabled.
if (passTiming)
return;
addInstrumentation(std::make_unique<PassTiming>(displayMode));
if (!config)
config = std::make_unique<PassManager::PassTimingConfig>();
addInstrumentation(std::make_unique<PassTiming>(std::move(config)));
passTiming = true;
}