Change dialect `printOperation()` hook to `getOperationPrinter()`

This makes the hook return a printer if available, instead of using LogicalResult  to
indicate if a printer was available (and invoked). This allows the caller to detect that
the dialect has a printer for a given operation without actually invoking the printer.
It'll be leveraged in a future revision to move printing the op name itself under control
of the ASMPrinter.

Differential Revision: https://reviews.llvm.org/D108803
This commit is contained in:
Mehdi Amini 2021-08-28 03:02:55 +00:00
parent 6726a3d858
commit fd87963eee
5 changed files with 22 additions and 20 deletions

View File

@ -121,8 +121,8 @@ public:
/// Print an operation registered to this dialect.
/// This hook is invoked for registered operation which don't override the
/// `print()` method to define their own custom assembly.
virtual LogicalResult printOperation(Operation *op,
OpAsmPrinter &printer) const;
virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
getOperationPrinter(Operation *op) const;
//===--------------------------------------------------------------------===//
// Verification Hooks
@ -297,8 +297,7 @@ class DialectRegistry {
public:
explicit DialectRegistry();
template <typename ConcreteDialect>
void insert() {
template <typename ConcreteDialect> void insert() {
insert(TypeID::get<ConcreteDialect>(),
ConcreteDialect::getDialectNamespace(),
static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
@ -364,8 +363,7 @@ public:
/// Add an external op interface model for an op that belongs to a dialect,
/// both provided as template parameters. The dialect must be present in the
/// registry.
template <typename OpTy, typename ModelTy>
void addOpInterface() {
template <typename OpTy, typename ModelTy> void addOpInterface() {
StringRef opName = OpTy::getOperationName();
StringRef dialectName = opName.split('.').first;
addObjectInterface(dialectName, TypeID::get<OpTy>(),
@ -426,8 +424,7 @@ private:
namespace llvm {
/// Provide isa functionality for Dialects.
template <typename T>
struct isa_impl<T, ::mlir::Dialect> {
template <typename T> struct isa_impl<T, ::mlir::Dialect> {
static inline bool doit(const ::mlir::Dialect &dialect) {
return mlir::TypeID::get<T>() == dialect.getTypeID();
}

View File

@ -2508,8 +2508,10 @@ void OperationPrinter::printOperation(Operation *op) {
}
// Otherwise try to dispatch to the dialect, if available.
if (Dialect *dialect = op->getDialect()) {
if (succeeded(dialect->printOperation(op, *this)))
if (auto opPrinter = dialect->getOperationPrinter(op)) {
opPrinter(op, *this);
return;
}
}
}

View File

@ -172,11 +172,11 @@ Dialect::getParseOperationHook(StringRef opName) const {
return None;
}
LogicalResult Dialect::printOperation(Operation *op,
OpAsmPrinter &printer) const {
llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
Dialect::getOperationPrinter(Operation *op) const {
assert(op->getDialect() == this &&
"Dialect hook invoked on non-dialect owned operation");
return failure();
return nullptr;
}
/// Utility function that returns if the given string is a valid dialect

View File

@ -313,14 +313,15 @@ TestDialect::getParseOperationHook(StringRef opName) const {
return None;
}
LogicalResult TestDialect::printOperation(Operation *op,
OpAsmPrinter &printer) const {
llvm::unique_function<void(Operation *, OpAsmPrinter &)>
TestDialect::getOperationPrinter(Operation *op) const {
StringRef opName = op->getName().getStringRef();
if (opName == "test.dialect_custom_printer") {
printer.getStream() << opName << " custom_format";
return success();
return [](Operation *op, OpAsmPrinter &printer) {
printer.getStream() << op->getName().getStringRef() << " custom_format";
};
}
return failure();
return {};
}
//===----------------------------------------------------------------------===//

View File

@ -39,15 +39,17 @@ def Test_Dialect : Dialect {
void registerTypes();
::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
::mlir::Type type) const override;
::mlir::Type type) const override;
void printAttribute(::mlir::Attribute attr,
::mlir::DialectAsmPrinter &printer) const override;
// Provides a custom printing/parsing for some operations.
::llvm::Optional<ParseOpHook>
getParseOperationHook(::llvm::StringRef opName) const override;
::mlir::LogicalResult printOperation(::mlir::Operation *op,
::mlir::OpAsmPrinter &printer) const override;
::llvm::unique_function<void(::mlir::Operation *,
::mlir::OpAsmPrinter &printer)>
getOperationPrinter(::mlir::Operation *op) const override;
private:
// Storage for a custom fallback interface.
void *fallbackEffectOpInterfaces;