[mlir] Update various operations to declaratively specify their assembly format.

Summary:
This revision switches over many operations to use the declarative methods for defining the assembly specification. This updates operations in the NVVM, ROCDL, Standard, and VectorOps dialects.

Differential Revision: https://reviews.llvm.org/D73407
This commit is contained in:
River Riddle 2020-01-30 11:31:50 -08:00
parent 1c158d0f90
commit 82170d5619
10 changed files with 34 additions and 312 deletions
mlir
include/mlir
Dialect
LLVMIR
StandardOps
VectorOps
IR
lib/Dialect
test/Dialect/LLVMIR

View File

@ -42,8 +42,7 @@ class NVVM_SpecialRegisterOp<string mnemonic,
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
string llvmBuilder = "$res = createIntrinsicCall(builder,"
# "llvm::Intrinsic::nvvm_" # !subst(".","_", mnemonic) # ");";
let parser = [{ return parseNVVMSpecialRegisterOp(parser, result); }];
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
let assemblyFormat = "attr-dict `:` type($res)";
}
//===----------------------------------------------------------------------===//
@ -77,8 +76,7 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0);
}];
let parser = [{ return success(); }];
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
let assemblyFormat = "attr-dict";
}
def NVVM_ShflBflyOp :
@ -129,8 +127,7 @@ def NVVM_MmaOp :
$res = createIntrinsicCall(
builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_row_f32_f32, $args);
}];
let parser = [{ return parseNVVMMmaOp(parser, result); }];
let printer = [{ printNVVMMmaOp(p, *this); }];
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
let verifier = [{ return ::verify(*this); }];
}

View File

@ -42,8 +42,7 @@ class ROCDL_SpecialRegisterOp<string mnemonic,
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
string llvmBuilder = "$res = createIntrinsicCall(builder,"
# "llvm::Intrinsic::amdgcn_" # !subst(".","_", mnemonic) # ");";
let parser = [{ return parseROCDLOp(parser, result); }];
let printer = [{ printROCDLOp(p, this->getOperation()); }];
let assemblyFormat = "attr-dict `:` type($res)";
}
class ROCDL_DeviceFunctionOp<string mnemonic, string device_function,
@ -52,8 +51,7 @@ class ROCDL_DeviceFunctionOp<string mnemonic, string device_function,
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
string llvmBuilder = "$res = createDeviceFunctionCall(builder, \""
# device_function # "\", " # parameter # ");";
let parser = [{ return parseROCDLOp(parser, result); }];
let printer = [{ printROCDLOp(p, this->getOperation()); }];
let assemblyFormat = "attr-dict `:` type($res)";
}
//===----------------------------------------------------------------------===//

View File

@ -304,6 +304,10 @@ def CallOp : Std_Op<"call", [CallOpInterface]> {
return getAttrOfType<SymbolRefAttr>("callee");
}
}];
let assemblyFormat = [{
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
}];
}
def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
@ -651,6 +655,7 @@ def DeallocOp : Std_Op<"dealloc"> {
let hasCanonicalizer = 1;
let hasFolder = 1;
let assemblyFormat = "$memref attr-dict `:` type($memref)";
}
def DimOp : Std_Op<"dim", [NoSideEffect]> {
@ -987,6 +992,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
}]>];
let hasFolder = 1;
let assemblyFormat = "operands attr-dict `:` type(operands)";
}
def RemFOp : FloatArithmeticOp<"remf"> {

View File

@ -203,6 +203,7 @@ def Vector_BroadcastOp :
return vector().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
}
def Vector_ShuffleOp :
@ -364,6 +365,10 @@ def Vector_ExtractSlicesOp :
static StringRef getSizesAttrName() { return "sizes"; }
static StringRef getStridesAttrName() { return "strides"; }
}];
let assemblyFormat = [{
$vector `,` $sizes `,` $strides attr-dict `:` type($vector) `into`
type(results)
}];
}
def Vector_InsertElementOp :
@ -482,6 +487,10 @@ def Vector_InsertSlicesOp :
static StringRef getSizesAttrName() { return "sizes"; }
static StringRef getStridesAttrName() { return "strides"; }
}];
let assemblyFormat = [{
$vectors `,` $sizes `,` $strides attr-dict `:` type($vectors) `into`
type(results)
}];
}
def Vector_InsertStridedSliceOp :
@ -727,6 +736,7 @@ def Vector_StridedSliceOp :
void getOffsets(SmallVectorImpl<int64_t> &results);
}];
let hasCanonicalizer = 1;
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
}
def Vector_TransferReadOp :
@ -939,6 +949,10 @@ def Vector_TransferWriteOp :
return memref().getType().cast<MemRefType>();
}
}];
let assemblyFormat = [{
$vector `,` $memref `[` $indices `]` attr-dict `:` type($vector) `,`
type($memref)
}];
}
def Vector_TypeCastOp :
@ -1017,6 +1031,7 @@ def Vector_ConstantMaskOp :
let extraClassDeclaration = [{
static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; }
}];
let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
}
def Vector_CreateMaskOp :
@ -1048,6 +1063,7 @@ def Vector_CreateMaskOp :
}];
let hasCanonicalizer = 1;
let assemblyFormat = "$operands attr-dict `:` type(results)";
}
def Vector_TupleOp :
@ -1148,6 +1164,7 @@ def Vector_PrintOp :
return source().getType();
}
}];
let assemblyFormat = "$source attr-dict `:` type($source)";
}
#endif // VECTOR_OPS

View File

@ -312,7 +312,8 @@ class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
def AnyInteger : Type<CPred<"$_self.isa<IntegerType>()">, "integer">;
// Index type.
def Index : Type<CPred<"$_self.isa<IndexType>()">, "index">;
def Index : Type<CPred<"$_self.isa<IndexType>()">, "index">,
BuildableType<"getIndexType()">;
// Integer type of a specific width.
class I<int width>

View File

@ -41,18 +41,6 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
p << " : " << op->getResultTypes();
}
// <operation> ::= `llvm.nvvm.XYZ` : type
static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser &parser,
OperationState &result) {
Type type;
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type))
return failure();
result.addTypes(type);
return success();
}
static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
return parser.getBuilder()
.getContext()
@ -103,41 +91,6 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
parser.getNameLoc(), result.operands));
}
// <operation> ::= `llvm.nvvm.mma.sync %lhs... %rhs... %acc...`
// : signature_type
static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 12> ops;
Type type;
llvm::SMLoc typeLoc;
if (parser.parseOperandList(ops) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&typeLoc) || parser.parseType(type)) {
return failure();
}
auto signature = type.dyn_cast<FunctionType>();
if (!signature) {
return parser.emitError(
typeLoc, "expected the type to be the full list of input and output");
}
if (signature.getNumResults() != 1) {
return parser.emitError(typeLoc, "expected single result");
}
return failure(parser.addTypeToList(signature.getResult(0), result.types) ||
parser.resolveOperands(ops, signature.getInputs(),
parser.getNameLoc(), result.operands));
}
static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) {
p << op.getOperationName() << " " << op.getOperands();
p.printOptionalAttrDict(op.getAttrs());
p << " : "
<< FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()),
op.getType(), op.getContext());
}
static LogicalResult verify(MmaOp op) {
auto dialect = op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
auto f16Ty = LLVM::LLVMType::getHalfTy(dialect);

View File

@ -30,24 +30,6 @@
using namespace mlir;
using namespace ROCDL;
//===----------------------------------------------------------------------===//
// Printing/parsing for ROCDL ops
//===----------------------------------------------------------------------===//
static void printROCDLOp(OpAsmPrinter &p, Operation *op) {
p << op->getName() << " " << op->getOperands();
if (op->getNumResults() > 0)
p << " : " << op->getResultTypes();
}
// <operation> ::= `rocdl.XYZ` : type
static ParseResult parseROCDLOp(OpAsmParser &parser, OperationState &result) {
Type type;
return failure(parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.addTypeToList(type, result.types));
}
//===----------------------------------------------------------------------===//
// ROCDLDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//

View File

@ -446,29 +446,6 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
// CallOp
//===----------------------------------------------------------------------===//
static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
FlatSymbolRefAttr calleeAttr;
FunctionType calleeType;
SmallVector<OpAsmParser::OperandType, 4> operands;
auto calleeLoc = parser.getNameLoc();
if (parser.parseAttribute(calleeAttr, "callee", result.attributes) ||
parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(calleeType) ||
parser.addTypesToList(calleeType.getResults(), result.types) ||
parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
result.operands))
return failure();
return success();
}
static void print(OpAsmPrinter &p, CallOp op) {
p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')';
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
p << " : " << op.getCalleeType();
}
static LogicalResult verify(CallOp op) {
// Check that the callee attribute was specified.
auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
@ -1184,19 +1161,6 @@ struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
};
} // end anonymous namespace.
static void print(OpAsmPrinter &p, DeallocOp op) {
p << "dealloc " << op.memref() << " : " << op.memref().getType();
}
static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType memrefInfo;
MemRefType type;
return failure(parser.parseOperand(memrefInfo) ||
parser.parseColonType(type) ||
parser.resolveOperand(memrefInfo, type, result.operands));
}
static LogicalResult verify(DeallocOp op) {
if (!op.memref().getType().isa<MemRefType>())
return op.emitOpError("operand must be a memref");
@ -1844,20 +1808,6 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
// RankOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, RankOp op) {
p << "rank " << op.getOperand() << " : " << op.getOperand().getType();
}
static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType operandInfo;
Type type;
Type indexType = parser.getBuilder().getIndexType();
return failure(parser.parseOperand(operandInfo) ||
parser.parseColonType(type) ||
parser.resolveOperand(operandInfo, type, result.operands) ||
parser.addTypeToList(indexType, result.types));
}
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
// Constant fold rank when the rank of the tensor is known.
auto type = getOperand().getType();

View File

@ -474,38 +474,6 @@ void ExtractSlicesOp::build(Builder *builder, OperationState &result,
result.addAttribute(getStridesAttrName(), stridesAttr);
}
static ParseResult parseExtractSlicesOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType operandInfo;
ArrayAttr sizesAttr;
StringRef sizesAttrName = ExtractSlicesOp::getSizesAttrName();
ArrayAttr stridesAttr;
StringRef stridesAttrName = ExtractSlicesOp::getStridesAttrName();
VectorType vectorType;
TupleType resultTupleType;
return failure(
parser.parseOperand(operandInfo) || parser.parseComma() ||
parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
parser.parseComma() ||
parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(vectorType) ||
parser.parseKeywordType("into", resultTupleType) ||
parser.resolveOperand(operandInfo, vectorType, result.operands) ||
parser.addTypeToList(resultTupleType, result.types));
}
static void print(OpAsmPrinter &p, ExtractSlicesOp op) {
p << op.getOperationName() << ' ' << op.vector() << ", ";
p << op.sizes() << ", " << op.strides();
p.printOptionalAttrDict(
op.getAttrs(),
/*elidedAttrs=*/{ExtractSlicesOp::getSizesAttrName(),
ExtractSlicesOp::getStridesAttrName()});
p << " : " << op.vector().getType();
p << " into " << op.getResultTupleType();
}
static LogicalResult
isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType,
TupleType tupleType, ArrayRef<int64_t> sizes,
@ -572,11 +540,6 @@ void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
// BroadcastOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, BroadcastOp op) {
p << op.getOperationName() << " " << op.source() << " : "
<< op.getSourceType() << " to " << op.getVectorType();
}
static LogicalResult verify(BroadcastOp op) {
VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
VectorType dstVectorType = op.getVectorType();
@ -601,18 +564,6 @@ static LogicalResult verify(BroadcastOp op) {
return success();
}
static ParseResult parseBroadcastOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType source;
Type sourceType;
VectorType vectorType;
return failure(parser.parseOperand(source) ||
parser.parseColonType(sourceType) ||
parser.parseKeywordType("to", vectorType) ||
parser.resolveOperand(source, sourceType, result.operands) ||
parser.addTypeToList(vectorType, result.types));
}
//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//
@ -808,38 +759,6 @@ static LogicalResult verify(InsertOp op) {
// InsertSlicesOp
//===----------------------------------------------------------------------===//
static ParseResult parseInsertSlicesOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType operandInfo;
ArrayAttr sizesAttr;
StringRef sizesAttrName = InsertSlicesOp::getSizesAttrName();
ArrayAttr stridesAttr;
StringRef stridesAttrName = InsertSlicesOp::getStridesAttrName();
TupleType tupleType;
VectorType resultVectorType;
return failure(
parser.parseOperand(operandInfo) || parser.parseComma() ||
parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
parser.parseComma() ||
parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(tupleType) ||
parser.parseKeywordType("into", resultVectorType) ||
parser.resolveOperand(operandInfo, tupleType, result.operands) ||
parser.addTypeToList(resultVectorType, result.types));
}
static void print(OpAsmPrinter &p, InsertSlicesOp op) {
p << op.getOperationName() << ' ' << op.vectors() << ", ";
p << op.sizes() << ", " << op.strides();
p.printOptionalAttrDict(
op.getAttrs(),
/*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(),
InsertSlicesOp::getStridesAttrName()});
p << " : " << op.vectors().getType();
p << " into " << op.getResultVectorType();
}
static LogicalResult verify(InsertSlicesOp op) {
SmallVector<int64_t, 4> sizes;
op.getSizes(sizes);
@ -1231,27 +1150,6 @@ void StridedSliceOp::build(Builder *builder, OperationState &result,
result.addAttribute(getStridesAttrName(), stridesAttr);
}
static void print(OpAsmPrinter &p, StridedSliceOp op) {
p << op.getOperationName() << " " << op.vector();
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.vector().getType() << " to " << op.getResult().getType();
}
static ParseResult parseStridedSliceOp(OpAsmParser &parser,
OperationState &result) {
llvm::SMLoc attributeLoc, typeLoc;
OpAsmParser::OperandType vector;
VectorType vectorType, resultVectorType;
return failure(parser.parseOperand(vector) ||
parser.getCurrentLocation(&attributeLoc) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.getCurrentLocation(&typeLoc) ||
parser.parseColonType(vectorType) ||
parser.parseKeywordType("to", resultVectorType) ||
parser.resolveOperand(vector, vectorType, result.operands) ||
parser.addTypeToList(resultVectorType, result.types));
}
static LogicalResult verify(StridedSliceOp op) {
auto type = op.getVectorType();
auto offsets = op.offsets();
@ -1519,35 +1417,6 @@ static LogicalResult verify(TransferReadOp op) {
//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, TransferWriteOp op) {
p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
<< op.indices() << "]";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getVectorType() << ", " << op.getMemRefType();
}
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
OperationState &result) {
llvm::SMLoc typesLoc;
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memRefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
SmallVector<Type, 2> types;
if (parser.parseOperand(storeValueInfo) || parser.parseComma() ||
parser.parseOperand(memRefInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
return failure();
if (types.size() != 2)
return parser.emitError(typesLoc, "two types required");
auto indexType = parser.getBuilder().getIndexType();
Type vectorType = types[0], memRefType = types[1];
return failure(
parser.resolveOperand(storeValueInfo, vectorType, result.operands) ||
parser.resolveOperand(memRefInfo, memRefType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands));
}
static LogicalResult verify(TransferWriteOp op) {
// Consistency of elemental types in memref and vector.
@ -1676,23 +1545,6 @@ OpFoldResult TupleGetOp::fold(ArrayRef<Attribute> operands) {
// ConstantMaskOp
//===----------------------------------------------------------------------===//
static ParseResult parseConstantMaskOp(OpAsmParser &parser,
OperationState &result) {
Type resultType;
ArrayAttr maskDimSizesAttr;
StringRef attrName = ConstantMaskOp::getMaskDimSizesAttrName();
return failure(
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseAttribute(maskDimSizesAttr, attrName, result.attributes) ||
parser.parseColonType(resultType) ||
parser.addTypeToList(resultType, result.types));
}
static void print(OpAsmPrinter &p, ConstantMaskOp op) {
p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : "
<< op.getResult().getType();
}
static LogicalResult verify(ConstantMaskOp &op) {
// Verify that array attr size matches the rank of the vector result.
auto resultType = op.getResult().getType().cast<VectorType>();
@ -1724,23 +1576,6 @@ static LogicalResult verify(ConstantMaskOp &op) {
// CreateMaskOp
//===----------------------------------------------------------------------===//
static ParseResult parseCreateMaskOp(OpAsmParser &parser,
OperationState &result) {
auto indexType = parser.getBuilder().getIndexType();
Type resultType;
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
return failure(
parser.parseOperandList(operandInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(resultType) ||
parser.resolveOperands(operandInfo, indexType, result.operands) ||
parser.addTypeToList(resultType, result.types));
}
static void print(OpAsmPrinter &p, CreateMaskOp op) {
p << op.getOperationName() << ' ' << op.operands() << " : " << op.getType();
}
static LogicalResult verify(CreateMaskOp op) {
// Verify that an operand was specified for each result vector each dimension.
if (op.getNumOperands() !=
@ -1750,23 +1585,6 @@ static LogicalResult verify(CreateMaskOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// PrintOp
//===----------------------------------------------------------------------===//
static ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType source;
Type sourceType;
return failure(parser.parseOperand(source) ||
parser.parseColonType(sourceType) ||
parser.resolveOperand(source, sourceType, result.operands));
}
static void print(OpAsmPrinter &p, PrintOp op) {
p << op.getOperationName() << ' ' << op.source() << " : "
<< op.getPrintType();
}
namespace {
// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.

View File

@ -366,7 +366,7 @@ func @nvvm_invalid_mma_6(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
%b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
// expected-error@+1 {{expected the type to be the full list of input and output}}
// expected-error@+1 {{invalid kind of type specified}}
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : !llvm<"{ float, float, float, float, float, float, float, float }">
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
}
@ -378,9 +378,9 @@ func @nvvm_invalid_mma_7(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
%b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
// expected-error@+1 {{expected single result}}
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
llvm.return %0 : (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
// expected-error@+1 {{op requires one result}}
%0:2 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
llvm.return %0#0 : !llvm<"{ float, float, float, float, float, float, float, float }">
}
// -----