Add a hook to the OpAsmDialectInterface to allow providing a special name for the operation result.

This generalizes the current special handling for constant operations(they get named 'cst'/'true'/'false'/etc.)

PiperOrigin-RevId: 264723379
This commit is contained in:
River Riddle 2019-08-21 16:50:30 -07:00 committed by A. Unique TensorFlower
parent d661eda811
commit c400c9a1ec
3 changed files with 57 additions and 24 deletions

View File

@ -536,7 +536,7 @@ private:
class OpAsmDialectInterface
: public DialectInterface::Base<OpAsmDialectInterface> {
public:
using Base::Base;
OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
/// Hooks for getting identifier aliases for symbols. The identifier is used
/// in place of the symbol when printing textual IR.
@ -553,6 +553,10 @@ public:
/// Hook for defining Type aliases.
virtual void
getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) const {}
/// Get a special name to use when printing the given operation. The desired
/// name should be streamed into 'os'.
virtual void getOpResultName(Operation *op, raw_ostream &os) const {}
};
} // end namespace mlir

View File

@ -34,6 +34,42 @@
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// StandardOpsDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct StdOpAsmInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
/// Get a special name to use when printing the given operation. The desired
/// name should be streamed into 'os'.
void getOpResultName(Operation *op, raw_ostream &os) const final {
if (ConstantOp constant = dyn_cast<ConstantOp>(op))
return getConstantOpResultName(constant, os);
}
/// Get a special name to use when printing the given constant.
static void getConstantOpResultName(ConstantOp op, raw_ostream &os) {
Type type = op.getType();
Attribute value = op.getValue();
if (auto intCst = value.dyn_cast<IntegerAttr>()) {
if (type.isIndex()) {
os << 'c' << intCst.getInt();
} else if (type.cast<IntegerType>().isInteger(1)) {
// i1 constants get special names.
os << (intCst.getInt() ? "true" : "false");
} else {
os << 'c' << intCst.getInt() << '_' << type;
}
} else if (type.isa<FunctionType>()) {
os << 'f';
} else {
os << "cst";
}
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// StandardOpsDialect
//===----------------------------------------------------------------------===//
@ -86,6 +122,7 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
#define GET_OP_LIST
#include "mlir/Dialect/StandardOps/Ops.cpp.inc"
>();
addInterfaces<StdOpAsmInterface>();
}
void mlir::printDimAndSymbolList(Operation::operand_iterator begin,

View File

@ -27,7 +27,6 @@
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
@ -135,6 +134,12 @@ public:
}
}
/// Get an instance of the OpAsmDialectInterface for the given dialect, or
/// null if one wasn't registered.
const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
return interfaces.getInterfaceFor(dialect);
}
private:
void recordAttributeReference(Attribute attr) {
// Don't recheck attributes that have already been seen or those that
@ -1364,26 +1369,11 @@ void OperationPrinter::numberValueID(Value *value) {
SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
// Give constant integers special names.
if (auto *op = value->getDefiningOp()) {
Attribute cst;
if (m_Constant(&cst).match(op)) {
Type type = op->getResult(0)->getType();
if (auto intCst = cst.dyn_cast<IntegerAttr>()) {
if (type.isIndex()) {
specialName << 'c' << intCst.getInt();
} else if (type.cast<IntegerType>().isInteger(1)) {
// i1 constants get special names.
specialName << (intCst.getInt() ? "true" : "false");
} else {
specialName << 'c' << intCst.getInt() << '_' << type;
}
} else if (type.isa<FunctionType>()) {
specialName << 'f';
} else {
specialName << "cst";
}
}
// Check to see if this value requested a special name.
auto *op = value->getDefiningOp();
if (state && op) {
if (auto *interface = state->getOpAsmInterface(op->getDialect()))
interface->getOpResultName(op, specialName);
}
if (specialNameBuffer.empty()) {
@ -1717,7 +1707,8 @@ void Operation::print(raw_ostream &os) {
while (auto *nextRegion = region->getParentRegion())
region = nextRegion;
ModulePrinter modulePrinter(os);
ModuleState state(getContext());
ModulePrinter modulePrinter(os, &state);
OperationPrinter(region, modulePrinter).print(this);
}
@ -1737,7 +1728,8 @@ void Block::print(raw_ostream &os) {
while (auto *nextRegion = region->getParentRegion())
region = nextRegion;
ModulePrinter modulePrinter(os);
ModuleState state(region->getContext());
ModulePrinter modulePrinter(os, &state);
OperationPrinter(region, modulePrinter).print(this);
}