Refactor the IRPrinting instrumentation to take a derivable config.

This allows for more interesting behavior from users, e.g. enabling the ability to dump the IR to a separate file for each pass invocation.

PiperOrigin-RevId: 284059447
This commit is contained in:
River Riddle 2019-12-05 14:52:28 -08:00 committed by A. Unique TensorFlower
parent daff60cd68
commit da53000fb4
3 changed files with 159 additions and 56 deletions

View File

@ -159,16 +159,63 @@ public:
/// Add the provided instrumentation to the pass manager.
void addInstrumentation(std::unique_ptr<PassInstrumentation> pi);
/// Add an instrumentation to print the IR before and after pass execution.
//===--------------------------------------------------------------------===//
// IR Printing
/// A configuration struct provided to the IR printer instrumentation.
class IRPrinterConfig {
public:
using PrintCallbackFn = function_ref<void(raw_ostream &)>;
/// Initialize the configuration.
/// * 'printModuleScope' signals if the top-level module IR should always be
/// printed. This should only be set to true when multi-threading is
/// disabled, otherwise we may try to print IR that is being modified
/// asynchronously.
explicit IRPrinterConfig(bool printModuleScope = false);
virtual ~IRPrinterConfig();
/// A hook that may be overridden by a derived config that checks if the IR
/// of 'operation' should be dumped *before* the pass 'pass' has been
/// executed. If the IR should be dumped, 'printCallback' should be invoked
/// with the stream to dump into.
virtual void printBeforeIfEnabled(Pass *pass, Operation *operation,
PrintCallbackFn printCallback);
/// A hook that may be overridden by a derived config that checks if the IR
/// of 'operation' should be dumped *after* the pass 'pass' has been
/// executed. If the IR should be dumped, 'printCallback' should be invoked
/// with the stream to dump into.
virtual void printAfterIfEnabled(Pass *pass, Operation *operation,
PrintCallbackFn printCallback);
/// Returns true if the IR should always be printed at the top-level scope.
bool shouldPrintAtModuleScope() const { return printModuleScope; }
private:
/// A flag that indicates if the IR should be printed at module scope.
bool printModuleScope;
};
/// Add an instrumentation to print the IR before and after pass execution,
/// using the provided configuration.
void enableIRPrinting(std::unique_ptr<IRPrinterConfig> config);
/// Add an instrumentation to print the IR before and after pass execution,
/// using the provided fields to generate a default configuration:
/// * 'shouldPrintBeforePass' and 'shouldPrintAfterPass' correspond to filter
/// functions that take a 'Pass *'. These function should return true if the
/// IR should be printed or not.
/// * 'printModuleScope' signals if the module IR should be printed, even for
/// non module passes.
/// functions that take a 'Pass *' and `Operation *`. These function should
/// return true if the IR should be printed or not.
/// * 'printModuleScope' signals if the module IR should be printed, even
/// for non module passes.
/// * 'out' corresponds to the stream to output the printed IR to.
void enableIRPrinting(std::function<bool(Pass *)> shouldPrintBeforePass,
std::function<bool(Pass *)> shouldPrintAfterPass,
bool printModuleScope, raw_ostream &out);
void enableIRPrinting(
std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
bool printModuleScope, raw_ostream &out);
//===--------------------------------------------------------------------===//
// Pass Timing
/// Add an instrumentation to time the execution of passes and the computation
/// of analyses.

View File

@ -27,19 +27,8 @@ using namespace mlir::detail;
namespace {
class IRPrinterInstrumentation : public PassInstrumentation {
public:
/// A filter function to decide if the given pass should be printed. Returns
/// true if the pass should be printed, false otherwise.
using ShouldPrintFn = std::function<bool(Pass *)>;
IRPrinterInstrumentation(ShouldPrintFn &&shouldPrintBeforePass,
ShouldPrintFn &&shouldPrintAfterPass,
bool printModuleScope, raw_ostream &out)
: shouldPrintBeforePass(shouldPrintBeforePass),
shouldPrintAfterPass(shouldPrintAfterPass),
printModuleScope(printModuleScope), out(out) {
assert((shouldPrintBeforePass || shouldPrintAfterPass) &&
"expected atleast one valid filter function");
}
IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)
: config(std::move(config)) {}
private:
/// Instrumentation hooks.
@ -47,14 +36,8 @@ private:
void runAfterPass(Pass *pass, Operation *op) override;
void runAfterPassFailed(Pass *pass, Operation *op) override;
/// Filter functions for before and after pass execution.
ShouldPrintFn shouldPrintBeforePass, shouldPrintAfterPass;
/// Flag to toggle if the printer should always print at module scope.
bool printModuleScope;
/// The stream to output to.
raw_ostream &out;
/// Configuration to use.
std::unique_ptr<PassManager::IRPrinterConfig> config;
};
} // end anonymous namespace
@ -96,45 +79,117 @@ static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
/// Instrumentation hooks.
void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
// Skip hidden passes and passes that the user filtered out.
if (!shouldPrintBeforePass || isHiddenPass(pass) ||
!shouldPrintBeforePass(pass))
if (isHiddenPass(pass))
return;
out << formatv("*** IR Dump Before {0} ***", pass->getName());
printIR(op, printModuleScope, out, OpPrintingFlags());
out << "\n\n";
config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) {
out << formatv("*** IR Dump Before {0} ***", pass->getName());
printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags());
out << "\n\n";
});
}
void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
// Skip hidden passes and passes that the user filtered out.
if (!shouldPrintAfterPass || isHiddenPass(pass) ||
!shouldPrintAfterPass(pass))
if (isHiddenPass(pass))
return;
out << formatv("*** IR Dump After {0} ***", pass->getName());
printIR(op, printModuleScope, out, OpPrintingFlags());
out << "\n\n";
config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
out << formatv("*** IR Dump After {0} ***", pass->getName());
printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags());
out << "\n\n";
});
}
void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
// Skip adaptor passes and passes that the user filtered out.
if (!shouldPrintAfterPass || isAdaptorPass(pass) ||
!shouldPrintAfterPass(pass))
if (isAdaptorPass(pass))
return;
out << formatv("*** IR Dump After {0} Failed ***", pass->getName());
printIR(op, printModuleScope, out, OpPrintingFlags().printGenericOpForm());
out << "\n\n";
config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
out << formatv("*** IR Dump After {0} Failed ***", pass->getName());
printIR(op, config->shouldPrintAtModuleScope(), out,
OpPrintingFlags().printGenericOpForm());
out << "\n\n";
});
}
//===----------------------------------------------------------------------===//
// IRPrinterConfig
//===----------------------------------------------------------------------===//
/// Initialize the configuration.
/// * 'printModuleScope' signals if the module IR should be printed, even
/// for non module passes.
PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope)
: printModuleScope(printModuleScope) {}
PassManager::IRPrinterConfig::~IRPrinterConfig() {}
/// A hook that may be overridden by a derived config that checks if the IR
/// of 'operation' should be dumped *before* the pass 'pass' has been
/// executed. If the IR should be dumped, 'printCallback' should be invoked
/// with the stream to dump into.
void PassManager::IRPrinterConfig::printBeforeIfEnabled(
Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
// By default, never print.
}
/// A hook that may be overridden by a derived config that checks if the IR
/// of 'operation' should be dumped *after* the pass 'pass' has been
/// executed. If the IR should be dumped, 'printCallback' should be invoked
/// with the stream to dump into.
void PassManager::IRPrinterConfig::printAfterIfEnabled(
Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
// By default, never print.
}
//===----------------------------------------------------------------------===//
// PassManager
//===----------------------------------------------------------------------===//
namespace {
/// Simple wrapper config that allows for the simpler interface defined above.
struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
BasicIRPrinterConfig(
std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
bool printModuleScope, raw_ostream &out)
: IRPrinterConfig(printModuleScope),
shouldPrintBeforePass(shouldPrintBeforePass),
shouldPrintAfterPass(shouldPrintAfterPass), out(out) {
assert((shouldPrintBeforePass || shouldPrintAfterPass) &&
"expected at least one valid filter function");
}
void printBeforeIfEnabled(Pass *pass, Operation *operation,
PrintCallbackFn printCallback) final {
if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation))
printCallback(out);
}
void printAfterIfEnabled(Pass *pass, Operation *operation,
PrintCallbackFn printCallback) final {
if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation))
printCallback(out);
}
/// Filter functions for before and after pass execution.
std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
/// The stream to output to.
raw_ostream &out;
};
} // end anonymous namespace
/// Add an instrumentation to print the IR before and after pass execution,
/// using the provided configuration.
void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
addInstrumentation(
std::make_unique<IRPrinterInstrumentation>(std::move(config)));
}
/// Add an instrumentation to print the IR before and after pass execution.
void PassManager::enableIRPrinting(
std::function<bool(Pass *)> shouldPrintBeforePass,
std::function<bool(Pass *)> shouldPrintAfterPass, bool printModuleScope,
raw_ostream &out) {
addInstrumentation(std::make_unique<IRPrinterInstrumentation>(
std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
bool printModuleScope, raw_ostream &out) {
enableIRPrinting(std::make_unique<BasicIRPrinterConfig>(
std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
printModuleScope, out));
}

View File

@ -104,16 +104,17 @@ static llvm::ManagedStatic<llvm::Optional<PassManagerOptions>> options;
/// Add an IR printing instrumentation if enabled by any 'print-ir' flags.
void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) {
std::function<bool(Pass *)> shouldPrintBeforePass, shouldPrintAfterPass;
std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
// Handle print-before.
if (printBeforeAll) {
// If we are printing before all, then just return true for the filter.
shouldPrintBeforePass = [](Pass *) { return true; };
shouldPrintBeforePass = [](Pass *, Operation *) { return true; };
} else if (printBefore.hasAnyOccurrences()) {
// Otherwise if there are specific passes to print before, then check to see
// if the pass info for the current pass is included in the list.
shouldPrintBeforePass = [&](Pass *pass) {
shouldPrintBeforePass = [&](Pass *pass, Operation *) {
auto *passInfo = pass->lookupPassInfo();
return passInfo && printBefore.contains(passInfo);
};
@ -122,11 +123,11 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) {
// Handle print-after.
if (printAfterAll) {
// If we are printing after all, then just return true for the filter.
shouldPrintAfterPass = [](Pass *) { return true; };
shouldPrintAfterPass = [](Pass *, Operation *) { return true; };
} else if (printAfter.hasAnyOccurrences()) {
// Otherwise if there are specific passes to print after, then check to see
// if the pass info for the current pass is included in the list.
shouldPrintAfterPass = [&](Pass *pass) {
shouldPrintAfterPass = [&](Pass *pass, Operation *) {
auto *passInfo = pass->lookupPassInfo();
return passInfo && printAfter.contains(passInfo);
};