NFC: Cleanup the various Op::print methods.

This cleans up the implementation of the various operation print methods. This is done via a combination of code cleanup, adding new streaming methods to the printer(e.g. operand ranges), etc.

PiperOrigin-RevId: 285285181
This commit is contained in:
River Riddle 2019-12-12 15:31:39 -08:00 committed by A. Unique TensorFlower
parent a50cb184a0
commit e7aa47ff11
11 changed files with 174 additions and 280 deletions

View File

@ -140,7 +140,7 @@ def NVVM_MmaOp :
}]; }];
let parser = [{ return parseNVVMMmaOp(parser, result); }]; let parser = [{ return parseNVVMMmaOp(parser, result); }];
let printer = [{ printNVVMMmaOp(p, *this); }]; let printer = [{ printNVVMMmaOp(p, *this); }];
let verifier = [{ return mlir::NVVM::verify(*this); }]; let verifier = [{ return ::verify(*this); }];
} }
#endif // NVVMIR_OPS #endif // NVVMIR_OPS

View File

@ -154,6 +154,18 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value &value) {
p.printOperand(&value); p.printOperand(&value);
return p; return p;
} }
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value *value) {
return p << *value;
}
template <typename T,
typename std::enable_if<std::is_convertible<T &, ValueRange>::value &&
!std::is_convertible<T &, Value *>::value,
T>::type * = nullptr>
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
p.printOperands(values);
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) { inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
p.printType(type); p.printType(type);
@ -170,14 +182,29 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) {
// FunctionType with the Type version above, not have it match this. // FunctionType with the Type version above, not have it match this.
template <typename T, typename std::enable_if< template <typename T, typename std::enable_if<
!std::is_convertible<T &, Value &>::value && !std::is_convertible<T &, Value &>::value &&
!std::is_convertible<T &, Value *>::value &&
!std::is_convertible<T &, Type &>::value && !std::is_convertible<T &, Type &>::value &&
!std::is_convertible<T &, Attribute &>::value, !std::is_convertible<T &, Attribute &>::value &&
!std::is_convertible<T &, ValueRange>::value &&
!llvm::is_one_of<T, bool>::value,
T>::type * = nullptr> T>::type * = nullptr>
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) { inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) {
p.getStream() << other; p.getStream() << other;
return p; return p;
} }
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) {
return p << (value ? StringRef("true") : "false");
}
template <typename IteratorT>
inline OpAsmPrinter &
operator<<(OpAsmPrinter &p,
const iterator_range<ValueTypeIterator<IteratorT>> &types) {
interleaveComma(types, p);
return p;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// OpAsmParser // OpAsmParser
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1985,18 +1985,12 @@ static ParseResult parseAffineMinOp(OpAsmParser &parser,
static void print(OpAsmPrinter &p, AffineMinOp op) { static void print(OpAsmPrinter &p, AffineMinOp op) {
p << op.getOperationName() << ' ' p << op.getOperationName() << ' '
<< op.getAttr(AffineMinOp::getMapAttrName()); << op.getAttr(AffineMinOp::getMapAttrName());
auto begin = op.operand_begin(); auto operands = op.getOperands();
auto end = op.operand_end();
unsigned numDims = op.map().getNumDims(); unsigned numDims = op.map().getNumDims();
p << '('; p << '(' << operands.take_front(numDims) << ')';
p.printOperands(begin, begin + numDims);
p << ')';
if (begin + numDims != end) { if (operands.size() != numDims)
p << '['; p << '[' << operands.drop_front(numDims) << ']';
p.printOperands(begin + numDims, end);
p << ']';
}
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
} }

View File

@ -289,8 +289,7 @@ void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
// Print the launch configuration. // Print the launch configuration.
p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword(); p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword();
printSizeAssignment(p, op.getGridSize(), printSizeAssignment(p, op.getGridSize(), operands.take_front(3),
operands.drop_back(operands.size() - 3),
op.getBlockIds()); op.getBlockIds());
p << ' ' << op.getThreadsKeyword(); p << ' ' << op.getThreadsKeyword();
printSizeAssignment(p, op.getBlockSize(), operands.slice(3, 3), printSizeAssignment(p, op.getBlockSize(), operands.slice(3, 3),
@ -303,25 +302,17 @@ void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
// Print the data argument remapping. // Print the data argument remapping.
if (!op.body().empty() && !operands.empty()) { if (!op.body().empty() && !operands.empty()) {
p << ' ' << op.getArgsKeyword() << '('; p << ' ' << op.getArgsKeyword() << '(';
for (unsigned i = 0, e = operands.size(); i < e; ++i) { Block *entryBlock = &op.body().front();
if (i != 0) interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
p << ", "; p << *entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i)
p << *op.body().front().getArgument(LaunchOp::kNumConfigRegionAttributes +
i)
<< " = " << *operands[i]; << " = " << *operands[i];
} });
p << ") "; p << ") ";
} }
// Print the types of data arguments. // Print the types of data arguments.
if (!operands.empty()) { if (!operands.empty())
p << ": "; p << ": " << operands.getTypes();
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
if (i != 0)
p << ", ";
p << operands[i]->getType();
}
}
p.printRegion(op.body(), /*printEntryBlockArgs=*/false); p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs());
@ -701,7 +692,7 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
return; return;
p << ' ' << keyword << '('; p << ' ' << keyword << '(';
interleaveComma(values, p.getStream(), interleaveComma(values, p,
[&p](BlockArgument *v) { p << *v << " : " << v->getType(); }); [&p](BlockArgument *v) { p << *v << " : " << v->getType(); });
p << ')'; p << ')';
} }

View File

@ -177,9 +177,8 @@ static void printGEPOp(OpAsmPrinter &p, GEPOp &op) {
SmallVector<Type, 8> types(op.getOperandTypes()); SmallVector<Type, 8> types(op.getOperandTypes());
auto funcTy = FunctionType::get(types, op.getType(), op.getContext()); auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
p << op.getOperationName() << ' ' << *op.base() << '['; p << op.getOperationName() << ' ' << *op.base() << '['
p.printOperands(std::next(op.operand_begin()), op.operand_end()); << op.getOperands().drop_front() << ']';
p << ']';
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs());
p << " : " << funcTy; p << " : " << funcTy;
} }
@ -312,10 +311,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
else else
p << *op.getOperand(0); p << *op.getOperand(0);
p << '('; p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
p.printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1));
p << ')';
p.printOptionalAttrDict(op.getAttrs(), {"callee"}); p.printOptionalAttrDict(op.getAttrs(), {"callee"});
// Reconstruct the function MLIR function type from operand and result types. // Reconstruct the function MLIR function type from operand and result types.
@ -938,8 +934,7 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
// Print the trailing type unless it's a string global. // Print the trailing type unless it's a string global.
if (op.getValueOrNull().dyn_cast_or_null<StringAttr>()) if (op.getValueOrNull().dyn_cast_or_null<StringAttr>())
return; return;
p << " : "; p << " : " << op.type();
p.printType(op.type());
Region &initializer = op.getInitializerRegion(); Region &initializer = op.getInitializerRegion();
if (!initializer.empty()) if (!initializer.empty())
@ -1346,8 +1341,7 @@ static LogicalResult verify(LLVMFuncOp op) {
static void printNullOp(OpAsmPrinter &p, LLVM::NullOp op) { static void printNullOp(OpAsmPrinter &p, LLVM::NullOp op) {
p << NullOp::getOperationName(); p << NullOp::getOperationName();
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs());
p << " : "; p << " : " << op.getType();
p.printType(op.getType());
} }
// <operation> = `llvm.mlir.null` : type // <operation> = `llvm.mlir.null` : type

View File

@ -37,18 +37,17 @@
#include "llvm/IR/Type.h" #include "llvm/IR/Type.h"
#include "llvm/Support/SourceMgr.h" #include "llvm/Support/SourceMgr.h"
namespace mlir { using namespace mlir;
namespace NVVM { using namespace NVVM;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops // Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) { static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
p << op->getName() << " "; p << op->getName() << " " << op->getOperands();
p.printOperands(op->getOperands());
if (op->getNumResults() > 0) if (op->getNumResults() > 0)
interleaveComma(op->getResultTypes(), p << " : "); p << " : " << op->getResultTypes();
} }
// <operation> ::= `llvm.nvvm.XYZ` : type // <operation> ::= `llvm.nvvm.XYZ` : type
@ -141,8 +140,7 @@ static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) {
} }
static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) { static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) {
p << op.getOperationName() << " "; p << op.getOperationName() << " " << op.getOperands();
p.printOperands(op.getOperands());
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs());
p << " : " p << " : "
<< FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()), << FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()),
@ -210,10 +208,11 @@ NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) {
allowUnknownOperations(); allowUnknownOperations();
} }
namespace mlir {
namespace NVVM {
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc" #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
static DialectRegistration<NVVMDialect> nvvmDialect;
} // namespace NVVM } // namespace NVVM
} // namespace mlir } // namespace mlir
static DialectRegistration<NVVMDialect> nvvmDialect;

View File

@ -36,18 +36,17 @@
#include "llvm/IR/Type.h" #include "llvm/IR/Type.h"
#include "llvm/Support/SourceMgr.h" #include "llvm/Support/SourceMgr.h"
namespace mlir { using namespace mlir;
namespace ROCDL { using namespace ROCDL;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Printing/parsing for ROCDL ops // Printing/parsing for ROCDL ops
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static void printROCDLOp(OpAsmPrinter &p, Operation *op) { static void printROCDLOp(OpAsmPrinter &p, Operation *op) {
p << op->getName() << " "; p << op->getName() << " " << op->getOperands();
p.printOperands(op->getOperands());
if (op->getNumResults() > 0) if (op->getNumResults() > 0)
interleaveComma(op->getResultTypes(), p << " : "); p << " : " << op->getResultTypes();
} }
// <operation> ::= `rocdl.XYZ` : type // <operation> ::= `rocdl.XYZ` : type
@ -73,10 +72,11 @@ ROCDLDialect::ROCDLDialect(MLIRContext *context) : Dialect("rocdl", context) {
allowUnknownOperations(); allowUnknownOperations();
} }
namespace mlir {
namespace ROCDL {
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc" #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
static DialectRegistration<ROCDLDialect> rocdlDialect;
} // namespace ROCDL } // namespace ROCDL
} // namespace mlir } // namespace mlir
static DialectRegistration<ROCDLDialect> rocdlDialect;

View File

@ -60,18 +60,16 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
llvm::StringSet<> linalgTraitAttrsSet; llvm::StringSet<> linalgTraitAttrsSet;
linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end()); linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end());
SmallVector<NamedAttribute, 8> attrs; SmallVector<NamedAttribute, 8> attrs;
for (auto attr : op.getAttrs()) { for (auto attr : op.getAttrs())
if (linalgTraitAttrsSet.count(attr.first.strref()) > 0) if (linalgTraitAttrsSet.count(attr.first.strref()) > 0)
attrs.push_back(attr); attrs.push_back(attr);
}
auto dictAttr = DictionaryAttr::get(attrs, op.getContext()); auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
p << op.getOperationName() << " " << dictAttr << " "; p << op.getOperationName() << " " << dictAttr << " " << op.getOperands();
p.printOperands(op.getOperands());
if (!op.region().empty()) if (!op.region().empty())
p.printRegion(op.region()); p.printRegion(op.region());
p.printOptionalAttrDict(op.getAttrs(), attrNames); p.printOptionalAttrDict(op.getAttrs(), attrNames);
p << ": "; p << ": " << op.getOperandTypes();
interleaveComma(op.getOperandTypes(), p);
} }
static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); } static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); }
@ -342,14 +340,13 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
} }
static void print(OpAsmPrinter &p, SliceOp op) { static void print(OpAsmPrinter &p, SliceOp op) {
p << SliceOp::getOperationName() << " " << *op.view() << "["; auto indexings = op.indexings();
p.printOperands(op.indexings()); p << SliceOp::getOperationName() << " " << *op.view() << "[" << indexings
p << "] "; << "] ";
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getBaseViewType(); p << " : " << op.getBaseViewType();
for (auto indexing : op.indexings()) { if (!indexings.empty())
p << ", " << indexing->getType(); p << ", " << op.indexings().getTypes();
}
p << ", " << op.getType(); p << ", " << op.getType();
} }
@ -455,16 +452,11 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
static void print(OpAsmPrinter &p, YieldOp op) { static void print(OpAsmPrinter &p, YieldOp op) {
p << op.getOperationName(); p << op.getOperationName();
if (op.getNumOperands() > 0) { if (op.getNumOperands() > 0)
p << ' '; p << ' ' << op.getOperands();
p.printOperands(op.operand_begin(), op.operand_end());
}
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs());
if (op.getNumOperands() > 0) { if (op.getNumOperands() > 0)
p << " : "; p << " : " << op.getOperandTypes();
interleaveComma(op.getOperands(), p,
[&](Value *e) { p.printType(e->getType()); });
}
} }
static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
@ -536,12 +528,9 @@ static LogicalResult verify(YieldOp op) {
// Where %0, %1 and %2 are ssa-values of type MemRefType with strides. // Where %0, %1 and %2 are ssa-values of type MemRefType with strides.
static void printLinalgLibraryOp(OpAsmPrinter &p, Operation *op) { static void printLinalgLibraryOp(OpAsmPrinter &p, Operation *op) {
assert(op->getAbstractOperation() && "unregistered operation"); assert(op->getAbstractOperation() && "unregistered operation");
p << op->getName().getStringRef() << "("; p << op->getName().getStringRef() << "(" << op->getOperands() << ")";
interleaveComma(op->getOperands(), p, [&](Value *v) { p << *v; });
p << ")";
p.printOptionalAttrDict(op->getAttrs()); p.printOptionalAttrDict(op->getAttrs());
p << " : "; p << " : " << op->getOperandTypes();
interleaveComma(op->getOperands(), p, [&](Value *v) { p << v->getType(); });
} }
static ParseResult parseLinalgLibraryOp(OpAsmParser &parser, static ParseResult parseLinalgLibraryOp(OpAsmParser &parser,

View File

@ -500,11 +500,8 @@ static ParseResult parseBitFieldExtractOp(OpAsmParser &parser,
} }
static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) { static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) {
printer << op->getName() << ' '; printer << op->getName() << ' ' << op->getOperands() << " : "
printer.printOperands(op->getOperands()); << op->getOperandTypes();
printer << " : " << op->getOperand(0)->getType() << ", "
<< op->getOperand(1)->getType() << ", "
<< op->getOperand(2)->getType();
} }
static LogicalResult verifyBitFieldExtractOp(Operation *op) { static LogicalResult verifyBitFieldExtractOp(Operation *op) {
@ -580,9 +577,8 @@ static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
} }
static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) { static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
printer << logicalOp->getName() << ' '; printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : "
printer.printOperands(logicalOp->getOperands()); << logicalOp->getOperand(0)->getType();
printer << " : " << logicalOp->getOperand(0)->getType();
} }
static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) { static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
@ -717,9 +713,7 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) { static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr() printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr()
<< '['; << '[' << op.indices() << "] : " << op.base_ptr()->getType();
printer.printOperands(op.indices());
printer << "] : " << op.base_ptr()->getType();
} }
static LogicalResult verify(spirv::AccessChainOp accessChainOp) { static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
@ -875,9 +869,8 @@ static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \"" printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \""
<< stringifyScope(atomOp.memory_scope()) << "\" \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
<< stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \"" << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
<< stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "; << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
printer.printOperands(atomOp.getOperands()); << atomOp.getOperands() << " : " << atomOp.pointer()->getType();
printer << " : " << atomOp.pointer()->getType();
} }
static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) { static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
@ -975,9 +968,9 @@ static ParseResult parseBitFieldInsertOp(OpAsmParser &parser,
static void print(spirv::BitFieldInsertOp bitFieldInsertOp, static void print(spirv::BitFieldInsertOp bitFieldInsertOp,
OpAsmPrinter &printer) { OpAsmPrinter &printer) {
printer << spirv::BitFieldInsertOp::getOperationName() << ' '; printer << spirv::BitFieldInsertOp::getOperationName() << ' '
printer.printOperands(bitFieldInsertOp.getOperands()); << bitFieldInsertOp.getOperands() << " : "
printer << " : " << bitFieldInsertOp.base()->getType() << ", " << bitFieldInsertOp.base()->getType() << ", "
<< bitFieldInsertOp.offset()->getType() << ", " << bitFieldInsertOp.offset()->getType() << ", "
<< bitFieldInsertOp.count()->getType(); << bitFieldInsertOp.count()->getType();
} }
@ -1072,8 +1065,8 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
} }
static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) { static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
printer << spirv::BranchConditionalOp::getOperationName() << ' '; printer << spirv::BranchConditionalOp::getOperationName() << ' '
printer.printOperand(branchOp.condition()); << branchOp.condition();
if (auto weights = branchOp.branch_weights()) { if (auto weights = branchOp.branch_weights()) {
printer << " ["; printer << " [";
@ -1148,9 +1141,9 @@ static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
static void print(spirv::CompositeConstructOp compositeConstructOp, static void print(spirv::CompositeConstructOp compositeConstructOp,
OpAsmPrinter &printer) { OpAsmPrinter &printer) {
printer << spirv::CompositeConstructOp::getOperationName() << " "; printer << spirv::CompositeConstructOp::getOperationName() << " "
printer.printOperands(compositeConstructOp.constituents()); << compositeConstructOp.constituents() << " : "
printer << " : " << compositeConstructOp.getResult()->getType(); << compositeConstructOp.getResult()->getType();
} }
static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
@ -1322,9 +1315,8 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) { static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value(); printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();
if (constOp.getType().isa<spirv::ArrayType>()) { if (constOp.getType().isa<spirv::ArrayType>())
printer << " : " << constOp.getType(); printer << " : " << constOp.getType();
}
} }
static LogicalResult verify(spirv::ConstantOp constOp) { static LogicalResult verify(spirv::ConstantOp constOp) {
@ -1577,9 +1569,8 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
<< execModeOp.fn() << " \"" << execModeOp.fn() << " \""
<< stringifyExecutionMode(execModeOp.execution_mode()) << "\""; << stringifyExecutionMode(execModeOp.execution_mode()) << "\"";
auto values = execModeOp.values(); auto values = execModeOp.values();
if (!values.size()) { if (!values.size())
return; return;
}
printer << ", "; printer << ", ";
interleaveComma(values, printer, [&](Attribute a) { interleaveComma(values, printer, [&](Attribute a) {
printer << a.cast<IntegerAttr>().getInt(); printer << a.cast<IntegerAttr>().getInt();
@ -1626,9 +1617,8 @@ static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) {
FunctionType::get(argTypes, resultTypes, functionCallOp.getContext()); FunctionType::get(argTypes, resultTypes, functionCallOp.getContext());
printer << spirv::FunctionCallOp::getOperationName() << ' ' printer << spirv::FunctionCallOp::getOperationName() << ' '
<< functionCallOp.getAttr(kCallee) << '('; << functionCallOp.getAttr(kCallee) << '('
printer.printOperands(functionCallOp.arguments()); << functionCallOp.arguments() << ") : " << functionType;
printer << ") : " << functionType;
} }
static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
@ -1829,9 +1819,8 @@ static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser,
static void print(spirv::GroupNonUniformBallotOp ballotOp, static void print(spirv::GroupNonUniformBallotOp ballotOp,
OpAsmPrinter &printer) { OpAsmPrinter &printer) {
printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \"" printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \""
<< stringifyScope(ballotOp.execution_scope()) << "\" "; << stringifyScope(ballotOp.execution_scope()) << "\" "
printer.printOperand(ballotOp.predicate()); << ballotOp.predicate() << " : " << ballotOp.getType();
printer << " : " << ballotOp.getType();
} }
static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) { static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
@ -1943,9 +1932,8 @@ static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs; SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass( StringRef sc = stringifyStorageClass(
loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass()); loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "; printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "
// Print the pointer operand. << loadOp.ptr();
printer.printOperand(loadOp.ptr());
printMemoryAccessAttribute(loadOp, printer, elidedAttrs); printMemoryAccessAttribute(loadOp, printer, elidedAttrs);
@ -2238,26 +2226,26 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
} }
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) { static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
auto *op = moduleOp.getOperation(); printer << spirv::ModuleOp::getOperationName();
// Only print out addressing model and memory model in a nicer way if both // Only print out addressing model and memory model in a nicer way if both
// presents. Otherwise, print them in the general form. This helps debugging // presents. Otherwise, print them in the general form. This helps
// ill-formed ModuleOp. // debugging ill-formed ModuleOp.
SmallVector<StringRef, 2> elidedAttrs; SmallVector<StringRef, 2> elidedAttrs;
auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>(); auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>(); auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
if (op->getAttr(addressingModelAttrName) && if (moduleOp.getAttr(addressingModelAttrName) &&
op->getAttr(memoryModelAttrName)) { moduleOp.getAttr(memoryModelAttrName)) {
printer << spirv::ModuleOp::getOperationName() << " \"" printer << " \""
<< spirv::stringifyAddressingModel(moduleOp.addressing_model()) << spirv::stringifyAddressingModel(moduleOp.addressing_model())
<< "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model()) << "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model())
<< '"'; << '"';
elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName}); elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
} }
printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false); /*printBlockTerminators=*/false);
printer.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs); printer.printOptionalAttrDictWithKeyword(moduleOp.getAttrs(), elidedAttrs);
} }
static LogicalResult verify(spirv::ModuleOp moduleOp) { static LogicalResult verify(spirv::ModuleOp moduleOp) {
@ -2417,9 +2405,8 @@ static ParseResult parseReturnValueOp(OpAsmParser &parser,
} }
static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) { static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) {
printer << spirv::ReturnValueOp::getOperationName() << ' '; printer << spirv::ReturnValueOp::getOperationName() << ' ' << retValOp.value()
printer.printOperand(retValOp.value()); << " : " << retValOp.value()->getType();
printer << " : " << retValOp.value()->getType();
} }
static LogicalResult verify(spirv::ReturnValueOp retValOp) { static LogicalResult verify(spirv::ReturnValueOp retValOp) {
@ -2471,13 +2458,8 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) {
} }
static void print(spirv::SelectOp op, OpAsmPrinter &printer) { static void print(spirv::SelectOp op, OpAsmPrinter &printer) {
printer << spirv::SelectOp::getOperationName() << " "; printer << spirv::SelectOp::getOperationName() << " " << op.getOperands()
<< " : " << op.condition()->getType() << ", "
// Print the operands.
printer.printOperands(op.getOperands());
// Print colon and types.
printer << " : " << op.condition()->getType() << ", "
<< op.result()->getType(); << op.result()->getType();
} }
@ -2788,8 +2770,7 @@ static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
printer.printSymbolName(constOp.sym_name()); printer.printSymbolName(constOp.sym_name());
if (auto specID = constOp.getAttrOfType<IntegerAttr>(kSpecIdAttrName)) if (auto specID = constOp.getAttrOfType<IntegerAttr>(kSpecIdAttrName))
printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')'; printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
printer << " = "; printer << " = " << constOp.default_value();
printer.printAttribute(constOp.default_value());
} }
static LogicalResult verify(spirv::SpecConstantOp constOp) { static LogicalResult verify(spirv::SpecConstantOp constOp) {
@ -2844,17 +2825,12 @@ static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs; SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass( StringRef sc = stringifyStorageClass(
storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass()); storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "; printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "
// Print the pointer operand << storeOp.ptr() << ", " << storeOp.value();
printer.printOperand(storeOp.ptr());
printer << ", ";
// Print the value operand
printer.printOperand(storeOp.value());
printMemoryAccessAttribute(storeOp, printer, elidedAttrs); printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
printer << " : " << storeOp.value()->getType(); printer << " : " << storeOp.value()->getType();
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
} }
@ -2885,9 +2861,8 @@ static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser,
} }
static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) { static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) {
printer << spirv::SubgroupBallotKHROp::getOperationName() << ' '; printer << spirv::SubgroupBallotKHROp::getOperationName() << ' '
printer.printOperand(ballotOp.predicate()); << ballotOp.predicate() << " : " << ballotOp.getType();
printer << " : " << ballotOp.getType();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2973,20 +2948,15 @@ static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
} }
static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) { static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
auto *op = varOp.getOperation();
SmallVector<StringRef, 4> elidedAttrs{ SmallVector<StringRef, 4> elidedAttrs{
spirv::attributeName<spirv::StorageClass>()}; spirv::attributeName<spirv::StorageClass>()};
printer << spirv::VariableOp::getOperationName(); printer << spirv::VariableOp::getOperationName();
// Print optional initializer // Print optional initializer
if (op->getNumOperands() > 0) { if (varOp.getNumOperands() != 0)
printer << " init("; printer << " init(" << varOp.initializer() << ")";
printer.printOperands(varOp.initializer());
printer << ")";
}
printVariableDecorations(op, printer, elidedAttrs);
printVariableDecorations(varOp, printer, elidedAttrs);
printer << " : " << varOp.getType(); printer << " : " << varOp.getType();
} }

View File

@ -166,15 +166,10 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
void mlir::printDimAndSymbolList(Operation::operand_iterator begin, void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
Operation::operand_iterator end, Operation::operand_iterator end,
unsigned numDims, OpAsmPrinter &p) { unsigned numDims, OpAsmPrinter &p) {
p << '('; Operation::operand_range operands(begin, end);
p.printOperands(begin, begin + numDims); p << '(' << operands.take_front(numDims) << ')';
p << ')'; if (operands.size() != numDims)
p << '[' << operands.drop_front(numDims) << ']';
if (begin + numDims != end) {
p << '[';
p.printOperands(begin + numDims, end);
p << ']';
}
} }
// Parses dimension and symbol list, and sets 'numDims' to the number of // Parses dimension and symbol list, and sets 'numDims' to the number of
@ -485,12 +480,9 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
} }
static void print(OpAsmPrinter &p, CallOp op) { static void print(OpAsmPrinter &p, CallOp op) {
p << "call " << op.getAttr("callee") << '('; p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')';
p.printOperands(op.getOperands());
p << ')';
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
p << " : "; p << " : " << op.getCalleeType();
p.printType(op.getCalleeType());
} }
static LogicalResult verify(CallOp op) { static LogicalResult verify(CallOp op) {
@ -572,11 +564,7 @@ static ParseResult parseCallIndirectOp(OpAsmParser &parser,
} }
static void print(OpAsmPrinter &p, CallIndirectOp op) { static void print(OpAsmPrinter &p, CallIndirectOp op) {
p << "call_indirect "; p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')';
p.printOperand(op.getCallee());
p << '(';
p.printOperands(op.getArgOperands());
p << ')';
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
p << " : " << op.getCallee()->getType(); p << " : " << op.getCallee()->getType();
} }
@ -690,12 +678,7 @@ static void print(OpAsmPrinter &p, CmpIOp op) {
auto predicateValue = auto predicateValue =
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt(); op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue)) p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
<< '"'; << '"' << ", " << op.lhs() << ", " << op.rhs();
p << ", ";
p.printOperand(op.lhs());
p << ", ";
p.printOperand(op.rhs());
p.printOptionalAttrDict(op.getAttrs(), p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{CmpIOp::getPredicateAttrName()}); /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
p << " : " << op.lhs()->getType(); p << " : " << op.lhs()->getType();
@ -851,15 +834,8 @@ static void print(OpAsmPrinter &p, CmpFOp op) {
assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) && assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) &&
predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) && predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) &&
"unknown predicate index"); "unknown predicate index");
Builder b(op.getContext()); p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs()
auto predicateStringAttr = << ", " << op.rhs();
b.getStringAttr(getCmpFPredicateNames()[predicateValue]);
p.printAttribute(predicateStringAttr);
p << ", ";
p.printOperand(op.lhs());
p << ", ";
p.printOperand(op.rhs());
p.printOptionalAttrDict(op.getAttrs(), p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{CmpFOp::getPredicateAttrName()}); /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
p << " : " << op.lhs()->getType(); p << " : " << op.lhs()->getType();
@ -1002,9 +978,7 @@ static ParseResult parseCondBranchOp(OpAsmParser &parser,
} }
static void print(OpAsmPrinter &p, CondBranchOp op) { static void print(OpAsmPrinter &p, CondBranchOp op) {
p << "cond_br "; p << "cond_br " << op.getCondition() << ", ";
p.printOperand(op.getCondition());
p << ", ";
p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
p << ", "; p << ", ";
p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
@ -1025,7 +999,7 @@ static void print(OpAsmPrinter &p, ConstantOp &op) {
if (op.getAttrs().size() > 1) if (op.getAttrs().size() > 1)
p << ' '; p << ' ';
p.printAttribute(op.getValue()); p << op.getValue();
// If the value is a symbol reference, print a trailing type. // If the value is a symbol reference, print a trailing type.
if (op.getValue().isa<SymbolRefAttr>()) if (op.getValue().isa<SymbolRefAttr>())
@ -1407,18 +1381,12 @@ void DmaStartOp::build(Builder *builder, OperationState &result,
} }
void DmaStartOp::print(OpAsmPrinter &p) { void DmaStartOp::print(OpAsmPrinter &p) {
p << "dma_start " << *getSrcMemRef() << '['; p << "dma_start " << *getSrcMemRef() << '[' << getSrcIndices() << "], "
p.printOperands(getSrcIndices()); << *getDstMemRef() << '[' << getDstIndices() << "], " << *getNumElements()
p << "], " << *getDstMemRef() << '['; << ", " << *getTagMemRef() << '[' << getTagIndices() << ']';
p.printOperands(getDstIndices()); if (isStrided())
p << "], " << *getNumElements(); p << ", " << *getStride() << ", " << *getNumElementsPerStride();
p << ", " << *getTagMemRef() << '[';
p.printOperands(getTagIndices());
p << ']';
if (isStrided()) {
p << ", " << *getStride();
p << ", " << *getNumElementsPerStride();
}
p.printOptionalAttrDict(getAttrs()); p.printOptionalAttrDict(getAttrs());
p << " : " << getSrcMemRef()->getType(); p << " : " << getSrcMemRef()->getType();
p << ", " << getDstMemRef()->getType(); p << ", " << getDstMemRef()->getType();
@ -1550,12 +1518,8 @@ void DmaWaitOp::build(Builder *builder, OperationState &result,
} }
void DmaWaitOp::print(OpAsmPrinter &p) { void DmaWaitOp::print(OpAsmPrinter &p) {
p << "dma_wait "; p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], "
p.printOperand(getTagMemRef()); << getNumElements();
p << '[';
p.printOperands(getTagIndices());
p << "], ";
p.printOperand(getNumElements());
p.printOptionalAttrDict(getAttrs()); p.printOptionalAttrDict(getAttrs());
p << " : " << getTagMemRef()->getType(); p << " : " << getTagMemRef()->getType();
} }
@ -1604,8 +1568,7 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, ExtractElementOp op) { static void print(OpAsmPrinter &p, ExtractElementOp op) {
p << "extract_element " << *op.getAggregate() << '['; p << "extract_element " << *op.getAggregate() << '[' << op.getIndices();
p.printOperands(op.getIndices());
p << ']'; p << ']';
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getAggregate()->getType(); p << " : " << op.getAggregate()->getType();
@ -1686,9 +1649,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, LoadOp op) { static void print(OpAsmPrinter &p, LoadOp op) {
p << "load " << *op.getMemRef() << '['; p << "load " << *op.getMemRef() << '[' << op.getIndices() << ']';
p.printOperands(op.getIndices());
p << ']';
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType(); p << " : " << op.getMemRefType();
} }
@ -1922,12 +1883,8 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
static void print(OpAsmPrinter &p, ReturnOp op) { static void print(OpAsmPrinter &p, ReturnOp op) {
p << "return"; p << "return";
if (op.getNumOperands() != 0) { if (op.getNumOperands() != 0)
p << ' '; p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
p.printOperands(op.getOperands());
p << " : ";
interleaveComma(op.getOperandTypes(), p);
}
} }
static LogicalResult verify(ReturnOp op) { static LogicalResult verify(ReturnOp op) {
@ -1984,9 +1941,7 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
} }
static void print(OpAsmPrinter &p, SelectOp op) { static void print(OpAsmPrinter &p, SelectOp op) {
p << "select "; p << "select " << op.getOperands() << " : " << op.getTrueValue()->getType();
p.printOperands(op.getOperands());
p << " : " << op.getTrueValue()->getType();
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs());
} }
@ -2093,9 +2048,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
static void print(OpAsmPrinter &p, StoreOp op) { static void print(OpAsmPrinter &p, StoreOp op) {
p << "store " << *op.getValueToStore(); p << "store " << *op.getValueToStore();
p << ", " << *op.getMemRef() << '['; p << ", " << *op.getMemRef() << '[' << op.getIndices() << ']';
p.printOperands(op.getIndices());
p << ']';
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType(); p << " : " << op.getMemRefType();
} }
@ -2339,9 +2292,7 @@ static void print(OpAsmPrinter &p, ViewOp op) {
auto *dynamicOffset = op.getDynamicOffset(); auto *dynamicOffset = op.getDynamicOffset();
if (dynamicOffset != nullptr) if (dynamicOffset != nullptr)
p.printOperand(dynamicOffset); p.printOperand(dynamicOffset);
p << "]["; p << "][" << op.getDynamicSizes() << ']';
p.printOperands(op.getDynamicSizes());
p << ']';
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
} }
@ -2609,13 +2560,8 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
} }
static void print(OpAsmPrinter &p, SubViewOp op) { static void print(OpAsmPrinter &p, SubViewOp op) {
p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; p << op.getOperationName() << ' ' << *op.getOperand(0) << '[' << op.offsets()
p.printOperands(op.offsets()); << "][" << op.sizes() << "][" << op.strides() << ']';
p << "][";
p.printOperands(op.sizes());
p << "][";
p.printOperands(op.strides());
p << ']';
SmallVector<StringRef, 1> elidedAttrs = { SmallVector<StringRef, 1> elidedAttrs = {
SubViewOp::getOperandSegmentSizeAttr()}; SubViewOp::getOperandSegmentSizeAttr()};

View File

@ -110,17 +110,16 @@ static void print(OpAsmPrinter &p, ContractionOp op) {
llvm::StringSet<> traitAttrsSet; llvm::StringSet<> traitAttrsSet;
traitAttrsSet.insert(attrNames.begin(), attrNames.end()); traitAttrsSet.insert(attrNames.begin(), attrNames.end());
SmallVector<NamedAttribute, 8> attrs; SmallVector<NamedAttribute, 8> attrs;
for (auto attr : op.getAttrs()) { for (auto attr : op.getAttrs())
if (traitAttrsSet.count(attr.first.strref()) > 0) if (traitAttrsSet.count(attr.first.strref()) > 0)
attrs.push_back(attr); attrs.push_back(attr);
}
auto dictAttr = DictionaryAttr::get(attrs, op.getContext()); auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", "; p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", ";
p << *op.rhs() << ", " << *op.acc(); p << *op.rhs() << ", " << *op.acc();
if (llvm::size(op.masks()) == 2) { if (op.masks().size() == 2)
p << ", " << **op.masks().begin(); p << ", " << op.masks();
p << ", " << **(op.masks().begin() + 1);
}
p.printOptionalAttrDict(op.getAttrs(), attrNames); p.printOptionalAttrDict(op.getAttrs(), attrNames);
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into " p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
<< op.getResultType(); << op.getResultType();
@ -417,9 +416,8 @@ static LogicalResult verify(vector::ExtractOp op) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, BroadcastOp op) { static void print(OpAsmPrinter &p, BroadcastOp op) {
p << op.getOperationName() << " " << *op.source(); p << op.getOperationName() << " " << *op.source() << " : "
p << " : " << op.getSourceType(); << op.getSourceType() << " to " << op.getVectorType();
p << " to " << op.getVectorType();
} }
static LogicalResult verify(BroadcastOp op) { static LogicalResult verify(BroadcastOp op) {
@ -560,8 +558,7 @@ static void print(OpAsmPrinter &p, InsertOp op) {
p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
<< op.position(); << op.position();
p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()}); p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()});
p << " : " << op.getSourceType(); p << " : " << op.getSourceType() << " into " << op.getDestVectorType();
p << " into " << op.getDestVectorType();
} }
static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) { static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) {
@ -789,8 +786,8 @@ static LogicalResult verify(InsertStridedSliceOp op) {
static void print(OpAsmPrinter &p, OuterProductOp op) { static void print(OpAsmPrinter &p, OuterProductOp op) {
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs(); p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
if (llvm::size(op.acc()) > 0) if (!op.acc().empty())
p << ", " << **op.acc().begin(); p << ", " << op.acc();
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType(); p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
} }
@ -1034,16 +1031,10 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
} }
static void print(OpAsmPrinter &p, TransferReadOp op) { static void print(OpAsmPrinter &p, TransferReadOp op) {
p << op.getOperationName() << " "; p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
p.printOperand(op.memref()); << "], " << op.padding() << " ";
p << "[";
p.printOperands(op.indices());
p << "], ";
p.printOperand(op.padding());
p << " ";
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType(); p << " : " << op.getMemRefType() << ", " << op.getVectorType();
p << ", " << op.getVectorType();
} }
ParseResult parseTransferReadOp(OpAsmParser &parser, OperationState &result) { ParseResult parseTransferReadOp(OpAsmParser &parser, OperationState &result) {
@ -1106,15 +1097,10 @@ static LogicalResult verify(TransferReadOp op) {
// TransferWriteOp // TransferWriteOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, TransferWriteOp op) { static void print(OpAsmPrinter &p, TransferWriteOp op) {
p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref(); p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref()
p << "["; << "[" << op.indices() << "]";
p.printOperands(op.indices());
p << "]";
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs());
p << " : "; p << " : " << op.getVectorType() << ", " << op.getMemRefType();
p.printType(op.getVectorType());
p << ", ";
p.printType(op.getMemRefType());
} }
ParseResult parseTransferWriteOp(OpAsmParser &parser, OperationState &result) { ParseResult parseTransferWriteOp(OpAsmParser &parser, OperationState &result) {
@ -1180,13 +1166,13 @@ void TypeCastOp::build(Builder *builder, OperationState &result,
inferVectorTypeCastResultType(source->getType().cast<MemRefType>())); inferVectorTypeCastResultType(source->getType().cast<MemRefType>()));
} }
static void print(OpAsmPrinter &p, TypeCastOp &op) { static void print(OpAsmPrinter &p, TypeCastOp op) {
auto type = op.getOperand()->getType().cast<MemRefType>(); auto type = op.getOperand()->getType().cast<MemRefType>();
p << op.getOperationName() << ' ' << *op.memref() << " : " << type << " to " p << op.getOperationName() << ' ' << *op.memref() << " : " << type << " to "
<< inferVectorTypeCastResultType(type); << inferVectorTypeCastResultType(type);
} }
static LogicalResult verify(TypeCastOp &op) { static LogicalResult verify(TypeCastOp op) {
auto resultType = inferVectorTypeCastResultType(op.getMemRefType()); auto resultType = inferVectorTypeCastResultType(op.getMemRefType());
if (op.getResultMemRefType() != resultType) if (op.getResultMemRefType() != resultType)
return op.emitOpError("expects result type to be: ") << resultType; return op.emitOpError("expects result type to be: ") << resultType;
@ -1208,9 +1194,9 @@ ParseResult parseConstantMaskOp(OpAsmParser &parser, OperationState &result) {
parser.addTypeToList(resultType, result.types)); parser.addTypeToList(resultType, result.types));
} }
static void print(OpAsmPrinter &p, ConstantMaskOp &op) { static void print(OpAsmPrinter &p, ConstantMaskOp op) {
p << op.getOperationName() << ' ' << op.mask_dim_sizes(); p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : "
p << " : " << op.getResult()->getType(); << op.getResult()->getType();
} }
static LogicalResult verify(ConstantMaskOp &op) { static LogicalResult verify(ConstantMaskOp &op) {
@ -1256,13 +1242,11 @@ ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) {
parser.addTypeToList(resultType, result.types)); parser.addTypeToList(resultType, result.types));
} }
static void print(OpAsmPrinter &p, CreateMaskOp &op) { static void print(OpAsmPrinter &p, CreateMaskOp op) {
p << op.getOperationName() << ' '; p << op.getOperationName() << ' ' << op.operands() << " : " << op.getType();
p.printOperands(op.operands());
p << " : " << op.getResult()->getType();
} }
static LogicalResult verify(CreateMaskOp &op) { static LogicalResult verify(CreateMaskOp op) {
// Verify that an operand was specified for each result vector each dimension. // Verify that an operand was specified for each result vector each dimension.
if (op.getNumOperands() != if (op.getNumOperands() !=
op.getResult()->getType().cast<VectorType>().getRank()) op.getResult()->getType().cast<VectorType>().getRank())