NFC: Cleanup the various Op::print methods.

This cleans up the implementation of the various operation print methods. This is done via a combination of code cleanup, adding new streaming methods to the printer(e.g. operand ranges), etc.

PiperOrigin-RevId: 285285181
This commit is contained in:
River Riddle 2019-12-12 15:31:39 -08:00 committed by A. Unique TensorFlower
parent a50cb184a0
commit e7aa47ff11
11 changed files with 174 additions and 280 deletions

View File

@ -140,7 +140,7 @@ def NVVM_MmaOp :
}];
let parser = [{ return parseNVVMMmaOp(parser, result); }];
let printer = [{ printNVVMMmaOp(p, *this); }];
let verifier = [{ return mlir::NVVM::verify(*this); }];
let verifier = [{ return ::verify(*this); }];
}
#endif // NVVMIR_OPS

View File

@ -154,6 +154,18 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value &value) {
p.printOperand(&value);
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value *value) {
return p << *value;
}
template <typename T,
typename std::enable_if<std::is_convertible<T &, ValueRange>::value &&
!std::is_convertible<T &, Value *>::value,
T>::type * = nullptr>
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
p.printOperands(values);
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
p.printType(type);
@ -170,14 +182,29 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) {
// FunctionType with the Type version above, not have it match this.
template <typename T, typename std::enable_if<
!std::is_convertible<T &, Value &>::value &&
!std::is_convertible<T &, Value *>::value &&
!std::is_convertible<T &, Type &>::value &&
!std::is_convertible<T &, Attribute &>::value,
!std::is_convertible<T &, Attribute &>::value &&
!std::is_convertible<T &, ValueRange>::value &&
!llvm::is_one_of<T, bool>::value,
T>::type * = nullptr>
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) {
p.getStream() << other;
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) {
return p << (value ? StringRef("true") : "false");
}
template <typename IteratorT>
inline OpAsmPrinter &
operator<<(OpAsmPrinter &p,
const iterator_range<ValueTypeIterator<IteratorT>> &types) {
interleaveComma(types, p);
return p;
}
//===----------------------------------------------------------------------===//
// OpAsmParser
//===----------------------------------------------------------------------===//

View File

@ -1985,18 +1985,12 @@ static ParseResult parseAffineMinOp(OpAsmParser &parser,
static void print(OpAsmPrinter &p, AffineMinOp op) {
p << op.getOperationName() << ' '
<< op.getAttr(AffineMinOp::getMapAttrName());
auto begin = op.operand_begin();
auto end = op.operand_end();
auto operands = op.getOperands();
unsigned numDims = op.map().getNumDims();
p << '(';
p.printOperands(begin, begin + numDims);
p << ')';
p << '(' << operands.take_front(numDims) << ')';
if (begin + numDims != end) {
p << '[';
p.printOperands(begin + numDims, end);
p << ']';
}
if (operands.size() != numDims)
p << '[' << operands.drop_front(numDims) << ']';
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
}

View File

@ -289,8 +289,7 @@ void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
// Print the launch configuration.
p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword();
printSizeAssignment(p, op.getGridSize(),
operands.drop_back(operands.size() - 3),
printSizeAssignment(p, op.getGridSize(), operands.take_front(3),
op.getBlockIds());
p << ' ' << op.getThreadsKeyword();
printSizeAssignment(p, op.getBlockSize(), operands.slice(3, 3),
@ -303,25 +302,17 @@ void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
// Print the data argument remapping.
if (!op.body().empty() && !operands.empty()) {
p << ' ' << op.getArgsKeyword() << '(';
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
if (i != 0)
p << ", ";
p << *op.body().front().getArgument(LaunchOp::kNumConfigRegionAttributes +
i)
Block *entryBlock = &op.body().front();
interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
p << *entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i)
<< " = " << *operands[i];
}
});
p << ") ";
}
// Print the types of data arguments.
if (!operands.empty()) {
p << ": ";
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
if (i != 0)
p << ", ";
p << operands[i]->getType();
}
}
if (!operands.empty())
p << ": " << operands.getTypes();
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(op.getAttrs());
@ -701,7 +692,7 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
return;
p << ' ' << keyword << '(';
interleaveComma(values, p.getStream(),
interleaveComma(values, p,
[&p](BlockArgument *v) { p << *v << " : " << v->getType(); });
p << ')';
}

View File

@ -177,9 +177,8 @@ static void printGEPOp(OpAsmPrinter &p, GEPOp &op) {
SmallVector<Type, 8> types(op.getOperandTypes());
auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
p << op.getOperationName() << ' ' << *op.base() << '[';
p.printOperands(std::next(op.operand_begin()), op.operand_end());
p << ']';
p << op.getOperationName() << ' ' << *op.base() << '['
<< op.getOperands().drop_front() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << funcTy;
}
@ -312,10 +311,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
else
p << *op.getOperand(0);
p << '(';
p.printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1));
p << ')';
p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
p.printOptionalAttrDict(op.getAttrs(), {"callee"});
// Reconstruct the function MLIR function type from operand and result types.
@ -938,8 +934,7 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
// Print the trailing type unless it's a string global.
if (op.getValueOrNull().dyn_cast_or_null<StringAttr>())
return;
p << " : ";
p.printType(op.type());
p << " : " << op.type();
Region &initializer = op.getInitializerRegion();
if (!initializer.empty())
@ -1346,8 +1341,7 @@ static LogicalResult verify(LLVMFuncOp op) {
static void printNullOp(OpAsmPrinter &p, LLVM::NullOp op) {
p << NullOp::getOperationName();
p.printOptionalAttrDict(op.getAttrs());
p << " : ";
p.printType(op.getType());
p << " : " << op.getType();
}
// <operation> = `llvm.mlir.null` : type

View File

@ -37,18 +37,17 @@
#include "llvm/IR/Type.h"
#include "llvm/Support/SourceMgr.h"
namespace mlir {
namespace NVVM {
using namespace mlir;
using namespace NVVM;
//===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===//
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
p << op->getName() << " ";
p.printOperands(op->getOperands());
p << op->getName() << " " << op->getOperands();
if (op->getNumResults() > 0)
interleaveComma(op->getResultTypes(), p << " : ");
p << " : " << op->getResultTypes();
}
// <operation> ::= `llvm.nvvm.XYZ` : type
@ -141,8 +140,7 @@ static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) {
}
static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) {
p << op.getOperationName() << " ";
p.printOperands(op.getOperands());
p << op.getOperationName() << " " << op.getOperands();
p.printOptionalAttrDict(op.getAttrs());
p << " : "
<< FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()),
@ -210,10 +208,11 @@ NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) {
allowUnknownOperations();
}
namespace mlir {
namespace NVVM {
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
static DialectRegistration<NVVMDialect> nvvmDialect;
} // namespace NVVM
} // namespace mlir
static DialectRegistration<NVVMDialect> nvvmDialect;

View File

@ -36,18 +36,17 @@
#include "llvm/IR/Type.h"
#include "llvm/Support/SourceMgr.h"
namespace mlir {
namespace ROCDL {
using namespace mlir;
using namespace ROCDL;
//===----------------------------------------------------------------------===//
// Printing/parsing for ROCDL ops
//===----------------------------------------------------------------------===//
static void printROCDLOp(OpAsmPrinter &p, Operation *op) {
p << op->getName() << " ";
p.printOperands(op->getOperands());
p << op->getName() << " " << op->getOperands();
if (op->getNumResults() > 0)
interleaveComma(op->getResultTypes(), p << " : ");
p << " : " << op->getResultTypes();
}
// <operation> ::= `rocdl.XYZ` : type
@ -73,10 +72,11 @@ ROCDLDialect::ROCDLDialect(MLIRContext *context) : Dialect("rocdl", context) {
allowUnknownOperations();
}
namespace mlir {
namespace ROCDL {
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
static DialectRegistration<ROCDLDialect> rocdlDialect;
} // namespace ROCDL
} // namespace mlir
static DialectRegistration<ROCDLDialect> rocdlDialect;

View File

@ -60,18 +60,16 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
llvm::StringSet<> linalgTraitAttrsSet;
linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end());
SmallVector<NamedAttribute, 8> attrs;
for (auto attr : op.getAttrs()) {
for (auto attr : op.getAttrs())
if (linalgTraitAttrsSet.count(attr.first.strref()) > 0)
attrs.push_back(attr);
}
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
p << op.getOperationName() << " " << dictAttr << " ";
p.printOperands(op.getOperands());
p << op.getOperationName() << " " << dictAttr << " " << op.getOperands();
if (!op.region().empty())
p.printRegion(op.region());
p.printOptionalAttrDict(op.getAttrs(), attrNames);
p << ": ";
interleaveComma(op.getOperandTypes(), p);
p << ": " << op.getOperandTypes();
}
static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); }
@ -342,14 +340,13 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
}
static void print(OpAsmPrinter &p, SliceOp op) {
p << SliceOp::getOperationName() << " " << *op.view() << "[";
p.printOperands(op.indexings());
p << "] ";
auto indexings = op.indexings();
p << SliceOp::getOperationName() << " " << *op.view() << "[" << indexings
<< "] ";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getBaseViewType();
for (auto indexing : op.indexings()) {
p << ", " << indexing->getType();
}
if (!indexings.empty())
p << ", " << op.indexings().getTypes();
p << ", " << op.getType();
}
@ -455,16 +452,11 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
static void print(OpAsmPrinter &p, YieldOp op) {
p << op.getOperationName();
if (op.getNumOperands() > 0) {
p << ' ';
p.printOperands(op.operand_begin(), op.operand_end());
}
if (op.getNumOperands() > 0)
p << ' ' << op.getOperands();
p.printOptionalAttrDict(op.getAttrs());
if (op.getNumOperands() > 0) {
p << " : ";
interleaveComma(op.getOperands(), p,
[&](Value *e) { p.printType(e->getType()); });
}
if (op.getNumOperands() > 0)
p << " : " << op.getOperandTypes();
}
static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
@ -536,12 +528,9 @@ static LogicalResult verify(YieldOp op) {
// Where %0, %1 and %2 are ssa-values of type MemRefType with strides.
static void printLinalgLibraryOp(OpAsmPrinter &p, Operation *op) {
assert(op->getAbstractOperation() && "unregistered operation");
p << op->getName().getStringRef() << "(";
interleaveComma(op->getOperands(), p, [&](Value *v) { p << *v; });
p << ")";
p << op->getName().getStringRef() << "(" << op->getOperands() << ")";
p.printOptionalAttrDict(op->getAttrs());
p << " : ";
interleaveComma(op->getOperands(), p, [&](Value *v) { p << v->getType(); });
p << " : " << op->getOperandTypes();
}
static ParseResult parseLinalgLibraryOp(OpAsmParser &parser,

View File

@ -500,11 +500,8 @@ static ParseResult parseBitFieldExtractOp(OpAsmParser &parser,
}
static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) {
printer << op->getName() << ' ';
printer.printOperands(op->getOperands());
printer << " : " << op->getOperand(0)->getType() << ", "
<< op->getOperand(1)->getType() << ", "
<< op->getOperand(2)->getType();
printer << op->getName() << ' ' << op->getOperands() << " : "
<< op->getOperandTypes();
}
static LogicalResult verifyBitFieldExtractOp(Operation *op) {
@ -580,9 +577,8 @@ static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
}
static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
printer << logicalOp->getName() << ' ';
printer.printOperands(logicalOp->getOperands());
printer << " : " << logicalOp->getOperand(0)->getType();
printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : "
<< logicalOp->getOperand(0)->getType();
}
static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
@ -717,9 +713,7 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr()
<< '[';
printer.printOperands(op.indices());
printer << "] : " << op.base_ptr()->getType();
<< '[' << op.indices() << "] : " << op.base_ptr()->getType();
}
static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
@ -875,9 +869,8 @@ static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \""
<< stringifyScope(atomOp.memory_scope()) << "\" \""
<< stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
<< stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" ";
printer.printOperands(atomOp.getOperands());
printer << " : " << atomOp.pointer()->getType();
<< stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
<< atomOp.getOperands() << " : " << atomOp.pointer()->getType();
}
static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
@ -975,9 +968,9 @@ static ParseResult parseBitFieldInsertOp(OpAsmParser &parser,
static void print(spirv::BitFieldInsertOp bitFieldInsertOp,
OpAsmPrinter &printer) {
printer << spirv::BitFieldInsertOp::getOperationName() << ' ';
printer.printOperands(bitFieldInsertOp.getOperands());
printer << " : " << bitFieldInsertOp.base()->getType() << ", "
printer << spirv::BitFieldInsertOp::getOperationName() << ' '
<< bitFieldInsertOp.getOperands() << " : "
<< bitFieldInsertOp.base()->getType() << ", "
<< bitFieldInsertOp.offset()->getType() << ", "
<< bitFieldInsertOp.count()->getType();
}
@ -1072,8 +1065,8 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
}
static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
printer << spirv::BranchConditionalOp::getOperationName() << ' ';
printer.printOperand(branchOp.condition());
printer << spirv::BranchConditionalOp::getOperationName() << ' '
<< branchOp.condition();
if (auto weights = branchOp.branch_weights()) {
printer << " [";
@ -1148,9 +1141,9 @@ static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
static void print(spirv::CompositeConstructOp compositeConstructOp,
OpAsmPrinter &printer) {
printer << spirv::CompositeConstructOp::getOperationName() << " ";
printer.printOperands(compositeConstructOp.constituents());
printer << " : " << compositeConstructOp.getResult()->getType();
printer << spirv::CompositeConstructOp::getOperationName() << " "
<< compositeConstructOp.constituents() << " : "
<< compositeConstructOp.getResult()->getType();
}
static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
@ -1322,9 +1315,8 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();
if (constOp.getType().isa<spirv::ArrayType>()) {
if (constOp.getType().isa<spirv::ArrayType>())
printer << " : " << constOp.getType();
}
}
static LogicalResult verify(spirv::ConstantOp constOp) {
@ -1577,9 +1569,8 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
<< execModeOp.fn() << " \""
<< stringifyExecutionMode(execModeOp.execution_mode()) << "\"";
auto values = execModeOp.values();
if (!values.size()) {
if (!values.size())
return;
}
printer << ", ";
interleaveComma(values, printer, [&](Attribute a) {
printer << a.cast<IntegerAttr>().getInt();
@ -1626,9 +1617,8 @@ static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) {
FunctionType::get(argTypes, resultTypes, functionCallOp.getContext());
printer << spirv::FunctionCallOp::getOperationName() << ' '
<< functionCallOp.getAttr(kCallee) << '(';
printer.printOperands(functionCallOp.arguments());
printer << ") : " << functionType;
<< functionCallOp.getAttr(kCallee) << '('
<< functionCallOp.arguments() << ") : " << functionType;
}
static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
@ -1829,9 +1819,8 @@ static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser,
static void print(spirv::GroupNonUniformBallotOp ballotOp,
OpAsmPrinter &printer) {
printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \""
<< stringifyScope(ballotOp.execution_scope()) << "\" ";
printer.printOperand(ballotOp.predicate());
printer << " : " << ballotOp.getType();
<< stringifyScope(ballotOp.execution_scope()) << "\" "
<< ballotOp.predicate() << " : " << ballotOp.getType();
}
static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
@ -1943,9 +1932,8 @@ static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" ";
// Print the pointer operand.
printer.printOperand(loadOp.ptr());
printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "
<< loadOp.ptr();
printMemoryAccessAttribute(loadOp, printer, elidedAttrs);
@ -2238,26 +2226,26 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
}
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
auto *op = moduleOp.getOperation();
printer << spirv::ModuleOp::getOperationName();
// Only print out addressing model and memory model in a nicer way if both
// presents. Otherwise, print them in the general form. This helps debugging
// ill-formed ModuleOp.
// presents. Otherwise, print them in the general form. This helps
// debugging ill-formed ModuleOp.
SmallVector<StringRef, 2> elidedAttrs;
auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
if (op->getAttr(addressingModelAttrName) &&
op->getAttr(memoryModelAttrName)) {
printer << spirv::ModuleOp::getOperationName() << " \""
if (moduleOp.getAttr(addressingModelAttrName) &&
moduleOp.getAttr(memoryModelAttrName)) {
printer << " \""
<< spirv::stringifyAddressingModel(moduleOp.addressing_model())
<< "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model())
<< '"';
elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
}
printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
printer.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs);
printer.printOptionalAttrDictWithKeyword(moduleOp.getAttrs(), elidedAttrs);
}
static LogicalResult verify(spirv::ModuleOp moduleOp) {
@ -2417,9 +2405,8 @@ static ParseResult parseReturnValueOp(OpAsmParser &parser,
}
static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) {
printer << spirv::ReturnValueOp::getOperationName() << ' ';
printer.printOperand(retValOp.value());
printer << " : " << retValOp.value()->getType();
printer << spirv::ReturnValueOp::getOperationName() << ' ' << retValOp.value()
<< " : " << retValOp.value()->getType();
}
static LogicalResult verify(spirv::ReturnValueOp retValOp) {
@ -2471,13 +2458,8 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) {
}
static void print(spirv::SelectOp op, OpAsmPrinter &printer) {
printer << spirv::SelectOp::getOperationName() << " ";
// Print the operands.
printer.printOperands(op.getOperands());
// Print colon and types.
printer << " : " << op.condition()->getType() << ", "
printer << spirv::SelectOp::getOperationName() << " " << op.getOperands()
<< " : " << op.condition()->getType() << ", "
<< op.result()->getType();
}
@ -2788,8 +2770,7 @@ static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
printer.printSymbolName(constOp.sym_name());
if (auto specID = constOp.getAttrOfType<IntegerAttr>(kSpecIdAttrName))
printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
printer << " = ";
printer.printAttribute(constOp.default_value());
printer << " = " << constOp.default_value();
}
static LogicalResult verify(spirv::SpecConstantOp constOp) {
@ -2844,17 +2825,12 @@ static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" ";
// Print the pointer operand
printer.printOperand(storeOp.ptr());
printer << ", ";
// Print the value operand
printer.printOperand(storeOp.value());
printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "
<< storeOp.ptr() << ", " << storeOp.value();
printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
printer << " : " << storeOp.value()->getType();
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
@ -2885,9 +2861,8 @@ static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser,
}
static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) {
printer << spirv::SubgroupBallotKHROp::getOperationName() << ' ';
printer.printOperand(ballotOp.predicate());
printer << " : " << ballotOp.getType();
printer << spirv::SubgroupBallotKHROp::getOperationName() << ' '
<< ballotOp.predicate() << " : " << ballotOp.getType();
}
//===----------------------------------------------------------------------===//
@ -2973,20 +2948,15 @@ static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
}
static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
auto *op = varOp.getOperation();
SmallVector<StringRef, 4> elidedAttrs{
spirv::attributeName<spirv::StorageClass>()};
printer << spirv::VariableOp::getOperationName();
// Print optional initializer
if (op->getNumOperands() > 0) {
printer << " init(";
printer.printOperands(varOp.initializer());
printer << ")";
}
printVariableDecorations(op, printer, elidedAttrs);
if (varOp.getNumOperands() != 0)
printer << " init(" << varOp.initializer() << ")";
printVariableDecorations(varOp, printer, elidedAttrs);
printer << " : " << varOp.getType();
}

View File

@ -166,15 +166,10 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
Operation::operand_iterator end,
unsigned numDims, OpAsmPrinter &p) {
p << '(';
p.printOperands(begin, begin + numDims);
p << ')';
if (begin + numDims != end) {
p << '[';
p.printOperands(begin + numDims, end);
p << ']';
}
Operation::operand_range operands(begin, end);
p << '(' << operands.take_front(numDims) << ')';
if (operands.size() != numDims)
p << '[' << operands.drop_front(numDims) << ']';
}
// Parses dimension and symbol list, and sets 'numDims' to the number of
@ -485,12 +480,9 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, CallOp op) {
p << "call " << op.getAttr("callee") << '(';
p.printOperands(op.getOperands());
p << ')';
p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')';
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
p << " : ";
p.printType(op.getCalleeType());
p << " : " << op.getCalleeType();
}
static LogicalResult verify(CallOp op) {
@ -572,11 +564,7 @@ static ParseResult parseCallIndirectOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, CallIndirectOp op) {
p << "call_indirect ";
p.printOperand(op.getCallee());
p << '(';
p.printOperands(op.getArgOperands());
p << ')';
p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')';
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
p << " : " << op.getCallee()->getType();
}
@ -690,12 +678,7 @@ static void print(OpAsmPrinter &p, CmpIOp op) {
auto predicateValue =
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
<< '"';
p << ", ";
p.printOperand(op.lhs());
p << ", ";
p.printOperand(op.rhs());
<< '"' << ", " << op.lhs() << ", " << op.rhs();
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
p << " : " << op.lhs()->getType();
@ -851,15 +834,8 @@ static void print(OpAsmPrinter &p, CmpFOp op) {
assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) &&
predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) &&
"unknown predicate index");
Builder b(op.getContext());
auto predicateStringAttr =
b.getStringAttr(getCmpFPredicateNames()[predicateValue]);
p.printAttribute(predicateStringAttr);
p << ", ";
p.printOperand(op.lhs());
p << ", ";
p.printOperand(op.rhs());
p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs()
<< ", " << op.rhs();
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
p << " : " << op.lhs()->getType();
@ -1002,9 +978,7 @@ static ParseResult parseCondBranchOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, CondBranchOp op) {
p << "cond_br ";
p.printOperand(op.getCondition());
p << ", ";
p << "cond_br " << op.getCondition() << ", ";
p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
p << ", ";
p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
@ -1025,7 +999,7 @@ static void print(OpAsmPrinter &p, ConstantOp &op) {
if (op.getAttrs().size() > 1)
p << ' ';
p.printAttribute(op.getValue());
p << op.getValue();
// If the value is a symbol reference, print a trailing type.
if (op.getValue().isa<SymbolRefAttr>())
@ -1407,18 +1381,12 @@ void DmaStartOp::build(Builder *builder, OperationState &result,
}
void DmaStartOp::print(OpAsmPrinter &p) {
p << "dma_start " << *getSrcMemRef() << '[';
p.printOperands(getSrcIndices());
p << "], " << *getDstMemRef() << '[';
p.printOperands(getDstIndices());
p << "], " << *getNumElements();
p << ", " << *getTagMemRef() << '[';
p.printOperands(getTagIndices());
p << ']';
if (isStrided()) {
p << ", " << *getStride();
p << ", " << *getNumElementsPerStride();
}
p << "dma_start " << *getSrcMemRef() << '[' << getSrcIndices() << "], "
<< *getDstMemRef() << '[' << getDstIndices() << "], " << *getNumElements()
<< ", " << *getTagMemRef() << '[' << getTagIndices() << ']';
if (isStrided())
p << ", " << *getStride() << ", " << *getNumElementsPerStride();
p.printOptionalAttrDict(getAttrs());
p << " : " << getSrcMemRef()->getType();
p << ", " << getDstMemRef()->getType();
@ -1550,12 +1518,8 @@ void DmaWaitOp::build(Builder *builder, OperationState &result,
}
void DmaWaitOp::print(OpAsmPrinter &p) {
p << "dma_wait ";
p.printOperand(getTagMemRef());
p << '[';
p.printOperands(getTagIndices());
p << "], ";
p.printOperand(getNumElements());
p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], "
<< getNumElements();
p.printOptionalAttrDict(getAttrs());
p << " : " << getTagMemRef()->getType();
}
@ -1604,8 +1568,7 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, ExtractElementOp op) {
p << "extract_element " << *op.getAggregate() << '[';
p.printOperands(op.getIndices());
p << "extract_element " << *op.getAggregate() << '[' << op.getIndices();
p << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getAggregate()->getType();
@ -1686,9 +1649,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, LoadOp op) {
p << "load " << *op.getMemRef() << '[';
p.printOperands(op.getIndices());
p << ']';
p << "load " << *op.getMemRef() << '[' << op.getIndices() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType();
}
@ -1922,12 +1883,8 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
static void print(OpAsmPrinter &p, ReturnOp op) {
p << "return";
if (op.getNumOperands() != 0) {
p << ' ';
p.printOperands(op.getOperands());
p << " : ";
interleaveComma(op.getOperandTypes(), p);
}
if (op.getNumOperands() != 0)
p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
}
static LogicalResult verify(ReturnOp op) {
@ -1984,9 +1941,7 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, SelectOp op) {
p << "select ";
p.printOperands(op.getOperands());
p << " : " << op.getTrueValue()->getType();
p << "select " << op.getOperands() << " : " << op.getTrueValue()->getType();
p.printOptionalAttrDict(op.getAttrs());
}
@ -2093,9 +2048,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
static void print(OpAsmPrinter &p, StoreOp op) {
p << "store " << *op.getValueToStore();
p << ", " << *op.getMemRef() << '[';
p.printOperands(op.getIndices());
p << ']';
p << ", " << *op.getMemRef() << '[' << op.getIndices() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType();
}
@ -2339,9 +2292,7 @@ static void print(OpAsmPrinter &p, ViewOp op) {
auto *dynamicOffset = op.getDynamicOffset();
if (dynamicOffset != nullptr)
p.printOperand(dynamicOffset);
p << "][";
p.printOperands(op.getDynamicSizes());
p << ']';
p << "][" << op.getDynamicSizes() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
}
@ -2609,13 +2560,8 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, SubViewOp op) {
p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
p.printOperands(op.offsets());
p << "][";
p.printOperands(op.sizes());
p << "][";
p.printOperands(op.strides());
p << ']';
p << op.getOperationName() << ' ' << *op.getOperand(0) << '[' << op.offsets()
<< "][" << op.sizes() << "][" << op.strides() << ']';
SmallVector<StringRef, 1> elidedAttrs = {
SubViewOp::getOperandSegmentSizeAttr()};

View File

@ -110,17 +110,16 @@ static void print(OpAsmPrinter &p, ContractionOp op) {
llvm::StringSet<> traitAttrsSet;
traitAttrsSet.insert(attrNames.begin(), attrNames.end());
SmallVector<NamedAttribute, 8> attrs;
for (auto attr : op.getAttrs()) {
for (auto attr : op.getAttrs())
if (traitAttrsSet.count(attr.first.strref()) > 0)
attrs.push_back(attr);
}
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", ";
p << *op.rhs() << ", " << *op.acc();
if (llvm::size(op.masks()) == 2) {
p << ", " << **op.masks().begin();
p << ", " << **(op.masks().begin() + 1);
}
if (op.masks().size() == 2)
p << ", " << op.masks();
p.printOptionalAttrDict(op.getAttrs(), attrNames);
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
<< op.getResultType();
@ -417,9 +416,8 @@ static LogicalResult verify(vector::ExtractOp op) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, BroadcastOp op) {
p << op.getOperationName() << " " << *op.source();
p << " : " << op.getSourceType();
p << " to " << op.getVectorType();
p << op.getOperationName() << " " << *op.source() << " : "
<< op.getSourceType() << " to " << op.getVectorType();
}
static LogicalResult verify(BroadcastOp op) {
@ -560,8 +558,7 @@ static void print(OpAsmPrinter &p, InsertOp op) {
p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
<< op.position();
p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()});
p << " : " << op.getSourceType();
p << " into " << op.getDestVectorType();
p << " : " << op.getSourceType() << " into " << op.getDestVectorType();
}
static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) {
@ -789,8 +786,8 @@ static LogicalResult verify(InsertStridedSliceOp op) {
static void print(OpAsmPrinter &p, OuterProductOp op) {
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
if (llvm::size(op.acc()) > 0)
p << ", " << **op.acc().begin();
if (!op.acc().empty())
p << ", " << op.acc();
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
}
@ -1034,16 +1031,10 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
}
static void print(OpAsmPrinter &p, TransferReadOp op) {
p << op.getOperationName() << " ";
p.printOperand(op.memref());
p << "[";
p.printOperands(op.indices());
p << "], ";
p.printOperand(op.padding());
p << " ";
p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
<< "], " << op.padding() << " ";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType();
p << ", " << op.getVectorType();
p << " : " << op.getMemRefType() << ", " << op.getVectorType();
}
ParseResult parseTransferReadOp(OpAsmParser &parser, OperationState &result) {
@ -1106,15 +1097,10 @@ static LogicalResult verify(TransferReadOp op) {
// TransferWriteOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, TransferWriteOp op) {
p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref();
p << "[";
p.printOperands(op.indices());
p << "]";
p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref()
<< "[" << op.indices() << "]";
p.printOptionalAttrDict(op.getAttrs());
p << " : ";
p.printType(op.getVectorType());
p << ", ";
p.printType(op.getMemRefType());
p << " : " << op.getVectorType() << ", " << op.getMemRefType();
}
ParseResult parseTransferWriteOp(OpAsmParser &parser, OperationState &result) {
@ -1180,13 +1166,13 @@ void TypeCastOp::build(Builder *builder, OperationState &result,
inferVectorTypeCastResultType(source->getType().cast<MemRefType>()));
}
static void print(OpAsmPrinter &p, TypeCastOp &op) {
static void print(OpAsmPrinter &p, TypeCastOp op) {
auto type = op.getOperand()->getType().cast<MemRefType>();
p << op.getOperationName() << ' ' << *op.memref() << " : " << type << " to "
<< inferVectorTypeCastResultType(type);
}
static LogicalResult verify(TypeCastOp &op) {
static LogicalResult verify(TypeCastOp op) {
auto resultType = inferVectorTypeCastResultType(op.getMemRefType());
if (op.getResultMemRefType() != resultType)
return op.emitOpError("expects result type to be: ") << resultType;
@ -1208,9 +1194,9 @@ ParseResult parseConstantMaskOp(OpAsmParser &parser, OperationState &result) {
parser.addTypeToList(resultType, result.types));
}
static void print(OpAsmPrinter &p, ConstantMaskOp &op) {
p << op.getOperationName() << ' ' << op.mask_dim_sizes();
p << " : " << op.getResult()->getType();
static void print(OpAsmPrinter &p, ConstantMaskOp op) {
p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : "
<< op.getResult()->getType();
}
static LogicalResult verify(ConstantMaskOp &op) {
@ -1256,13 +1242,11 @@ ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) {
parser.addTypeToList(resultType, result.types));
}
static void print(OpAsmPrinter &p, CreateMaskOp &op) {
p << op.getOperationName() << ' ';
p.printOperands(op.operands());
p << " : " << op.getResult()->getType();
static void print(OpAsmPrinter &p, CreateMaskOp op) {
p << op.getOperationName() << ' ' << op.operands() << " : " << op.getType();
}
static LogicalResult verify(CreateMaskOp &op) {
static LogicalResult verify(CreateMaskOp op) {
// Verify that an operand was specified for each result vector each dimension.
if (op.getNumOperands() !=
op.getResult()->getType().cast<VectorType>().getRank())