forked from OSchip/llvm-project
Add a hook to the OpAsmDialectInterface to allow providing a special name for the operation result.
This generalizes the current special handling for constant operations(they get named 'cst'/'true'/'false'/etc.) PiperOrigin-RevId: 264723379
This commit is contained in:
parent
d661eda811
commit
c400c9a1ec
|
@ -536,7 +536,7 @@ private:
|
|||
class OpAsmDialectInterface
|
||||
: public DialectInterface::Base<OpAsmDialectInterface> {
|
||||
public:
|
||||
using Base::Base;
|
||||
OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
|
||||
|
||||
/// Hooks for getting identifier aliases for symbols. The identifier is used
|
||||
/// in place of the symbol when printing textual IR.
|
||||
|
@ -553,6 +553,10 @@ public:
|
|||
/// Hook for defining Type aliases.
|
||||
virtual void
|
||||
getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) const {}
|
||||
|
||||
/// Get a special name to use when printing the given operation. The desired
|
||||
/// name should be streamed into 'os'.
|
||||
virtual void getOpResultName(Operation *op, raw_ostream &os) const {}
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -34,6 +34,42 @@
|
|||
#include "llvm/Support/raw_ostream.h"
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StandardOpsDialect Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
namespace {
|
||||
struct StdOpAsmInterface : public OpAsmDialectInterface {
|
||||
using OpAsmDialectInterface::OpAsmDialectInterface;
|
||||
|
||||
/// Get a special name to use when printing the given operation. The desired
|
||||
/// name should be streamed into 'os'.
|
||||
void getOpResultName(Operation *op, raw_ostream &os) const final {
|
||||
if (ConstantOp constant = dyn_cast<ConstantOp>(op))
|
||||
return getConstantOpResultName(constant, os);
|
||||
}
|
||||
|
||||
/// Get a special name to use when printing the given constant.
|
||||
static void getConstantOpResultName(ConstantOp op, raw_ostream &os) {
|
||||
Type type = op.getType();
|
||||
Attribute value = op.getValue();
|
||||
if (auto intCst = value.dyn_cast<IntegerAttr>()) {
|
||||
if (type.isIndex()) {
|
||||
os << 'c' << intCst.getInt();
|
||||
} else if (type.cast<IntegerType>().isInteger(1)) {
|
||||
// i1 constants get special names.
|
||||
os << (intCst.getInt() ? "true" : "false");
|
||||
} else {
|
||||
os << 'c' << intCst.getInt() << '_' << type;
|
||||
}
|
||||
} else if (type.isa<FunctionType>()) {
|
||||
os << 'f';
|
||||
} else {
|
||||
os << "cst";
|
||||
}
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StandardOpsDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -86,6 +122,7 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
|
|||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/StandardOps/Ops.cpp.inc"
|
||||
>();
|
||||
addInterfaces<StdOpAsmInterface>();
|
||||
}
|
||||
|
||||
void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
@ -135,6 +134,12 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
/// Get an instance of the OpAsmDialectInterface for the given dialect, or
|
||||
/// null if one wasn't registered.
|
||||
const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
|
||||
return interfaces.getInterfaceFor(dialect);
|
||||
}
|
||||
|
||||
private:
|
||||
void recordAttributeReference(Attribute attr) {
|
||||
// Don't recheck attributes that have already been seen or those that
|
||||
|
@ -1364,26 +1369,11 @@ void OperationPrinter::numberValueID(Value *value) {
|
|||
SmallString<32> specialNameBuffer;
|
||||
llvm::raw_svector_ostream specialName(specialNameBuffer);
|
||||
|
||||
// Give constant integers special names.
|
||||
if (auto *op = value->getDefiningOp()) {
|
||||
Attribute cst;
|
||||
if (m_Constant(&cst).match(op)) {
|
||||
Type type = op->getResult(0)->getType();
|
||||
if (auto intCst = cst.dyn_cast<IntegerAttr>()) {
|
||||
if (type.isIndex()) {
|
||||
specialName << 'c' << intCst.getInt();
|
||||
} else if (type.cast<IntegerType>().isInteger(1)) {
|
||||
// i1 constants get special names.
|
||||
specialName << (intCst.getInt() ? "true" : "false");
|
||||
} else {
|
||||
specialName << 'c' << intCst.getInt() << '_' << type;
|
||||
}
|
||||
} else if (type.isa<FunctionType>()) {
|
||||
specialName << 'f';
|
||||
} else {
|
||||
specialName << "cst";
|
||||
}
|
||||
}
|
||||
// Check to see if this value requested a special name.
|
||||
auto *op = value->getDefiningOp();
|
||||
if (state && op) {
|
||||
if (auto *interface = state->getOpAsmInterface(op->getDialect()))
|
||||
interface->getOpResultName(op, specialName);
|
||||
}
|
||||
|
||||
if (specialNameBuffer.empty()) {
|
||||
|
@ -1717,7 +1707,8 @@ void Operation::print(raw_ostream &os) {
|
|||
while (auto *nextRegion = region->getParentRegion())
|
||||
region = nextRegion;
|
||||
|
||||
ModulePrinter modulePrinter(os);
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter modulePrinter(os, &state);
|
||||
OperationPrinter(region, modulePrinter).print(this);
|
||||
}
|
||||
|
||||
|
@ -1737,7 +1728,8 @@ void Block::print(raw_ostream &os) {
|
|||
while (auto *nextRegion = region->getParentRegion())
|
||||
region = nextRegion;
|
||||
|
||||
ModulePrinter modulePrinter(os);
|
||||
ModuleState state(region->getContext());
|
||||
ModulePrinter modulePrinter(os, &state);
|
||||
OperationPrinter(region, modulePrinter).print(this);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue