diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h index e01957267bd9..cf09494cb2ed 100644 --- a/mlir/include/mlir/IR/Module.h +++ b/mlir/include/mlir/IR/Module.h @@ -65,7 +65,7 @@ public: Optional getName(); /// Print the this module in the custom top-level form. - void print(raw_ostream &os); + void print(raw_ostream &os, OpPrintingFlags flags = llvm::None); void dump(); //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index dd82e7b7f715..c500e7364fa4 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -105,7 +105,9 @@ public: MLIRContext *getContext() { return getOperation()->getContext(); } /// Print the operation to the given stream. - void print(raw_ostream &os) { state->print(os); } + void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) { + state->print(os, flags); + } /// Dump this operation. void dump() { state->dump(); } diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 5444d6c6912a..ff23f6a1b601 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -199,7 +199,7 @@ public: /// take O(N) where N is the number of operations within the parent block. bool isBeforeInBlock(Operation *other); - void print(raw_ostream &os); + void print(raw_ostream &os, OpPrintingFlags flags = llvm::None); void dump(); //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 70d5476f9473..5567af717f7a 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -452,6 +452,40 @@ private: } }; } // end namespace detail + +/// Set of flags used to control the behavior of the various IR print methods +/// (e.g. Operation::Print). +class OpPrintingFlags { +public: + OpPrintingFlags(); + OpPrintingFlags(llvm::NoneType) : OpPrintingFlags() {} + + /// Enable printing of debug information. If 'prettyForm' is set to true, + /// debug information is printed in a more readable 'pretty' form. Note: The + /// IR generated with 'prettyForm' is not parsable. + OpPrintingFlags &enableDebugInfo(bool prettyForm = false); + + /// Always print operations in the generic form. + OpPrintingFlags &printGenericOpForm(); + + /// Return if debug information should be printed. + bool shouldPrintDebugInfo() const; + + /// Return if debug information should be printed in the pretty form. + bool shouldPrintDebugInfoPrettyForm() const; + + /// Return if operations should be printed in the generic form. + bool shouldPrintGenericOpForm() const; + +private: + /// Print debug information. + bool printDebugInfoFlag : 1; + bool printDebugInfoPrettyFormFlag : 1; + + /// Print operations in the generic form. + bool printGenericOpFormFlag : 1; +}; + } // end namespace mlir namespace llvm { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index ce79db0ebe1e..a1cd863e7bdb 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -56,17 +56,15 @@ void OperationName::dump() const { print(llvm::errs()); } OpAsmPrinter::~OpAsmPrinter() {} //===----------------------------------------------------------------------===// -// ModuleState +// OpPrintingFlags //===----------------------------------------------------------------------===// -// TODO(riverriddle) Rethink this flag when we have a pass that can remove debug -// info or when we have a system for printer flags. static llvm::cl::opt - shouldPrintDebugInfoOpt("mlir-print-debuginfo", - llvm::cl::desc("Print debug info in MLIR output"), - llvm::cl::init(false)); + printDebugInfoOpt("mlir-print-debuginfo", + llvm::cl::desc("Print debug info in MLIR output"), + llvm::cl::init(false)); -static llvm::cl::opt printPrettyDebugInfo( +static llvm::cl::opt printPrettyDebugInfoOpt( "mlir-pretty-debuginfo", llvm::cl::desc("Print pretty debug info in MLIR output"), llvm::cl::init(false)); @@ -74,9 +72,48 @@ static llvm::cl::opt printPrettyDebugInfo( // Use the generic op output form in the operation printer even if the custom // form is defined. static llvm::cl::opt - printGenericOpForm("mlir-print-op-generic", - llvm::cl::desc("Print the generic op form"), - llvm::cl::init(false), llvm::cl::Hidden); + printGenericOpFormOpt("mlir-print-op-generic", + llvm::cl::desc("Print the generic op form"), + llvm::cl::init(false), llvm::cl::Hidden); + +/// Initialize the printing flags with default supplied by the cl::opts above. +OpPrintingFlags::OpPrintingFlags() + : printDebugInfoFlag(printDebugInfoOpt), + printDebugInfoPrettyFormFlag(printPrettyDebugInfoOpt), + printGenericOpFormFlag(printGenericOpFormOpt) {} + +/// Enable printing of debug information. If 'prettyForm' is set to true, +/// debug information is printed in a more readable 'pretty' form. +OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) { + printDebugInfoFlag = true; + printDebugInfoPrettyFormFlag = prettyForm; + return *this; +} + +/// Always print operations in the generic form. +OpPrintingFlags &OpPrintingFlags::printGenericOpForm() { + printGenericOpFormFlag = true; + return *this; +} + +/// Return if debug information should be printed. +bool OpPrintingFlags::shouldPrintDebugInfo() const { + return printDebugInfoFlag; +} + +/// Return if debug information should be printed in the pretty form. +bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { + return printDebugInfoPrettyFormFlag; +} + +/// Return if operations should be printed in the generic form. +bool OpPrintingFlags::shouldPrintGenericOpForm() const { + return printGenericOpFormFlag; +} + +//===----------------------------------------------------------------------===// +// ModuleState +//===----------------------------------------------------------------------===// namespace { /// A special index constant used for non-kind attribute aliases. @@ -322,10 +359,12 @@ void ModuleState::initialize(Operation *op) { namespace { class ModulePrinter { public: - ModulePrinter(raw_ostream &os, ModuleState *state = nullptr) - : os(os), state(state) {} + ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None, + ModuleState *state = nullptr) + : os(os), printerFlags(flags), state(state) {} explicit ModulePrinter(ModulePrinter &printer) - : os(printer.os), state(printer.state) {} + : os(printer.os), printerFlags(printer.printerFlags), + state(printer.state) {} template inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { @@ -370,6 +409,9 @@ protected: /// The output stream for the printer. raw_ostream &os; + /// A set of flags to control the printer's behavior. + OpPrintingFlags printerFlags; + /// An optional printer state for the module. ModuleState *state; }; @@ -377,7 +419,7 @@ protected: void ModulePrinter::printTrailingLocation(Location loc) { // Check to see if we are printing debug information. - if (!shouldPrintDebugInfoOpt) + if (!printerFlags.shouldPrintDebugInfo()) return; os << " "; @@ -499,7 +541,7 @@ static void printFloatValue(const APFloat &apValue, raw_ostream &os) { } void ModulePrinter::printLocation(LocationAttr loc) { - if (printPrettyDebugInfo) { + if (printerFlags.shouldPrintDebugInfoPrettyForm()) { printLocationInternal(loc, /*pretty=*/true); } else { os << "loc("; @@ -1597,7 +1639,7 @@ void OperationPrinter::printOperation(Operation *op) { // TODO(riverriddle): FuncOp cannot be round-tripped currently, as // FunctionType cannot be used in a TypeAttr. - if (printGenericOpForm && !isa(op)) + if (printerFlags.shouldPrintGenericOpForm() && !isa(op)) return printGenericOp(op); // Check to see if this is a known operation. If so, use the registered @@ -1755,10 +1797,10 @@ void Value::dump() { llvm::errs() << "\n"; } -void Operation::print(raw_ostream &os) { +void Operation::print(raw_ostream &os, OpPrintingFlags flags) { // Handle top-level operations. if (!getParent()) { - ModulePrinter modulePrinter(os); + ModulePrinter modulePrinter(os, flags); OperationPrinter(this, modulePrinter).print(this); return; } @@ -1774,7 +1816,7 @@ void Operation::print(raw_ostream &os) { region = nextRegion; ModuleState state(getContext()); - ModulePrinter modulePrinter(os, &state); + ModulePrinter modulePrinter(os, flags, &state); OperationPrinter(region, modulePrinter).print(this); } @@ -1795,7 +1837,7 @@ void Block::print(raw_ostream &os) { region = nextRegion; ModuleState state(region->getContext()); - ModulePrinter modulePrinter(os, &state); + ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state); OperationPrinter(region, modulePrinter).print(this); } @@ -1817,10 +1859,10 @@ void Block::printAsOperand(raw_ostream &os, bool printType) { OperationPrinter(region, modulePrinter).printBlockName(this); } -void ModuleOp::print(raw_ostream &os) { +void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) { ModuleState state(getContext()); state.initialize(*this); - ModulePrinter(os, &state).print(*this); + ModulePrinter(os, flags, &state).print(*this); } void ModuleOp::dump() { print(llvm::errs()); }