Add a new class, OpPrintingFlags, to enable programmatic control of Operation::print behavior.

This allows for controlling the behavior of the AsmPrinter programmatically, instead of relying exclusively on cl::opt flags. This will also allow for more fine-tuned control of printing behavior per callsite, instead of being applied globally.

PiperOrigin-RevId: 273368361
This commit is contained in:
River Riddle 2019-10-07 13:54:16 -07:00 committed by A. Unique TensorFlower
parent 9e9c3a009a
commit aeada290b8
5 changed files with 103 additions and 25 deletions

View File

@ -65,7 +65,7 @@ public:
Optional<StringRef> getName();
/// Print the this module in the custom top-level form.
void print(raw_ostream &os);
void print(raw_ostream &os, OpPrintingFlags flags = llvm::None);
void dump();
//===--------------------------------------------------------------------===//

View File

@ -105,7 +105,9 @@ public:
MLIRContext *getContext() { return getOperation()->getContext(); }
/// Print the operation to the given stream.
void print(raw_ostream &os) { state->print(os); }
void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
state->print(os, flags);
}
/// Dump this operation.
void dump() { state->dump(); }

View File

@ -199,7 +199,7 @@ public:
/// take O(N) where N is the number of operations within the parent block.
bool isBeforeInBlock(Operation *other);
void print(raw_ostream &os);
void print(raw_ostream &os, OpPrintingFlags flags = llvm::None);
void dump();
//===--------------------------------------------------------------------===//

View File

@ -452,6 +452,40 @@ private:
}
};
} // end namespace detail
/// Set of flags used to control the behavior of the various IR print methods
/// (e.g. Operation::Print).
class OpPrintingFlags {
public:
OpPrintingFlags();
OpPrintingFlags(llvm::NoneType) : OpPrintingFlags() {}
/// Enable printing of debug information. If 'prettyForm' is set to true,
/// debug information is printed in a more readable 'pretty' form. Note: The
/// IR generated with 'prettyForm' is not parsable.
OpPrintingFlags &enableDebugInfo(bool prettyForm = false);
/// Always print operations in the generic form.
OpPrintingFlags &printGenericOpForm();
/// Return if debug information should be printed.
bool shouldPrintDebugInfo() const;
/// Return if debug information should be printed in the pretty form.
bool shouldPrintDebugInfoPrettyForm() const;
/// Return if operations should be printed in the generic form.
bool shouldPrintGenericOpForm() const;
private:
/// Print debug information.
bool printDebugInfoFlag : 1;
bool printDebugInfoPrettyFormFlag : 1;
/// Print operations in the generic form.
bool printGenericOpFormFlag : 1;
};
} // end namespace mlir
namespace llvm {

View File

@ -56,17 +56,15 @@ void OperationName::dump() const { print(llvm::errs()); }
OpAsmPrinter::~OpAsmPrinter() {}
//===----------------------------------------------------------------------===//
// ModuleState
// OpPrintingFlags
//===----------------------------------------------------------------------===//
// TODO(riverriddle) Rethink this flag when we have a pass that can remove debug
// info or when we have a system for printer flags.
static llvm::cl::opt<bool>
shouldPrintDebugInfoOpt("mlir-print-debuginfo",
llvm::cl::desc("Print debug info in MLIR output"),
llvm::cl::init(false));
printDebugInfoOpt("mlir-print-debuginfo",
llvm::cl::desc("Print debug info in MLIR output"),
llvm::cl::init(false));
static llvm::cl::opt<bool> printPrettyDebugInfo(
static llvm::cl::opt<bool> printPrettyDebugInfoOpt(
"mlir-pretty-debuginfo",
llvm::cl::desc("Print pretty debug info in MLIR output"),
llvm::cl::init(false));
@ -74,9 +72,48 @@ static llvm::cl::opt<bool> printPrettyDebugInfo(
// Use the generic op output form in the operation printer even if the custom
// form is defined.
static llvm::cl::opt<bool>
printGenericOpForm("mlir-print-op-generic",
llvm::cl::desc("Print the generic op form"),
llvm::cl::init(false), llvm::cl::Hidden);
printGenericOpFormOpt("mlir-print-op-generic",
llvm::cl::desc("Print the generic op form"),
llvm::cl::init(false), llvm::cl::Hidden);
/// Initialize the printing flags with default supplied by the cl::opts above.
OpPrintingFlags::OpPrintingFlags()
: printDebugInfoFlag(printDebugInfoOpt),
printDebugInfoPrettyFormFlag(printPrettyDebugInfoOpt),
printGenericOpFormFlag(printGenericOpFormOpt) {}
/// Enable printing of debug information. If 'prettyForm' is set to true,
/// debug information is printed in a more readable 'pretty' form.
OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
printDebugInfoFlag = true;
printDebugInfoPrettyFormFlag = prettyForm;
return *this;
}
/// Always print operations in the generic form.
OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
printGenericOpFormFlag = true;
return *this;
}
/// Return if debug information should be printed.
bool OpPrintingFlags::shouldPrintDebugInfo() const {
return printDebugInfoFlag;
}
/// Return if debug information should be printed in the pretty form.
bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
return printDebugInfoPrettyFormFlag;
}
/// Return if operations should be printed in the generic form.
bool OpPrintingFlags::shouldPrintGenericOpForm() const {
return printGenericOpFormFlag;
}
//===----------------------------------------------------------------------===//
// ModuleState
//===----------------------------------------------------------------------===//
namespace {
/// A special index constant used for non-kind attribute aliases.
@ -322,10 +359,12 @@ void ModuleState::initialize(Operation *op) {
namespace {
class ModulePrinter {
public:
ModulePrinter(raw_ostream &os, ModuleState *state = nullptr)
: os(os), state(state) {}
ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
ModuleState *state = nullptr)
: os(os), printerFlags(flags), state(state) {}
explicit ModulePrinter(ModulePrinter &printer)
: os(printer.os), state(printer.state) {}
: os(printer.os), printerFlags(printer.printerFlags),
state(printer.state) {}
template <typename Container, typename UnaryFunctor>
inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
@ -370,6 +409,9 @@ protected:
/// The output stream for the printer.
raw_ostream &os;
/// A set of flags to control the printer's behavior.
OpPrintingFlags printerFlags;
/// An optional printer state for the module.
ModuleState *state;
};
@ -377,7 +419,7 @@ protected:
void ModulePrinter::printTrailingLocation(Location loc) {
// Check to see if we are printing debug information.
if (!shouldPrintDebugInfoOpt)
if (!printerFlags.shouldPrintDebugInfo())
return;
os << " ";
@ -499,7 +541,7 @@ static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
}
void ModulePrinter::printLocation(LocationAttr loc) {
if (printPrettyDebugInfo) {
if (printerFlags.shouldPrintDebugInfoPrettyForm()) {
printLocationInternal(loc, /*pretty=*/true);
} else {
os << "loc(";
@ -1597,7 +1639,7 @@ void OperationPrinter::printOperation(Operation *op) {
// TODO(riverriddle): FuncOp cannot be round-tripped currently, as
// FunctionType cannot be used in a TypeAttr.
if (printGenericOpForm && !isa<FuncOp>(op))
if (printerFlags.shouldPrintGenericOpForm() && !isa<FuncOp>(op))
return printGenericOp(op);
// Check to see if this is a known operation. If so, use the registered
@ -1755,10 +1797,10 @@ void Value::dump() {
llvm::errs() << "\n";
}
void Operation::print(raw_ostream &os) {
void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
// Handle top-level operations.
if (!getParent()) {
ModulePrinter modulePrinter(os);
ModulePrinter modulePrinter(os, flags);
OperationPrinter(this, modulePrinter).print(this);
return;
}
@ -1774,7 +1816,7 @@ void Operation::print(raw_ostream &os) {
region = nextRegion;
ModuleState state(getContext());
ModulePrinter modulePrinter(os, &state);
ModulePrinter modulePrinter(os, flags, &state);
OperationPrinter(region, modulePrinter).print(this);
}
@ -1795,7 +1837,7 @@ void Block::print(raw_ostream &os) {
region = nextRegion;
ModuleState state(region->getContext());
ModulePrinter modulePrinter(os, &state);
ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state);
OperationPrinter(region, modulePrinter).print(this);
}
@ -1817,10 +1859,10 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
OperationPrinter(region, modulePrinter).printBlockName(this);
}
void ModuleOp::print(raw_ostream &os) {
void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
ModuleState state(getContext());
state.initialize(*this);
ModulePrinter(os, &state).print(*this);
ModulePrinter(os, flags, &state).print(*this);
}
void ModuleOp::dump() { print(llvm::errs()); }