[mlir][spirv] Clean up coop matrix assembly declaration.

Address code review feedback and use declarative assembly format.

Differential Revision: https://reviews.llvm.org/D80687
This commit is contained in:
Thomas Raoux 2020-05-29 16:34:56 -07:00
parent 7265ff928a
commit c652c306a6
4 changed files with 14 additions and 56 deletions

View File

@ -39,6 +39,8 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
``` ```
}]; }];
let assemblyFormat = "attr-dict `:` $type";
let availability = [ let availability = [
MinVersion<SPV_V_1_0>, MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>, MaxVersion<SPV_V_1_5>,
@ -139,7 +141,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
// ----- // -----
def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV", def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
[NoSideEffect]> { [NoSideEffect, AllTypesMatch<["c", "result"]>]> {
let summary = "See extension SPV_NV_cooperative_matrix"; let summary = "See extension SPV_NV_cooperative_matrix";
let description = [{ let description = [{
@ -188,6 +190,10 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
``` ```
}]; }];
let assemblyFormat = [{
operands attr-dict`:` type($a) `,` type($b) `->` type($c)
}];
let availability = [ let availability = [
MinVersion<SPV_V_1_0>, MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>, MaxVersion<SPV_V_1_5>,

View File

@ -1134,12 +1134,11 @@ static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
return compositeConstructOp.emitError( return compositeConstructOp.emitError(
"has incorrect number of operands: expected ") "has incorrect number of operands: expected ")
<< "1, but provided " << constituents.size(); << "1, but provided " << constituents.size();
} else { } else if (constituents.size() != cType.getNumElements()) {
if (constituents.size() != cType.getNumElements()) return compositeConstructOp.emitError(
return compositeConstructOp.emitError( "has incorrect number of operands: expected ")
"has incorrect number of operands: expected ") << cType.getNumElements() << ", but provided "
<< cType.getNumElements() << ", but provided " << constituents.size();
<< constituents.size();
} }
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
@ -2735,57 +2734,10 @@ static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
printer << " : " << coopMatrix.getOperand(1).getType(); printer << " : " << coopMatrix.getOperand(1).getType();
} }
//===----------------------------------------------------------------------===//
// spv.CooperativeMatrixLengthNV
//===----------------------------------------------------------------------===//
static ParseResult parseCooperativeMatrixLengthNVOp(OpAsmParser &parser,
OperationState &state) {
OpAsmParser::OperandType operandInfo;
Type dstType = parser.getBuilder().getIntegerType(32);
Type type;
if (parser.parseColonType(type)) {
return failure();
}
state.addAttribute(kTypeAttrName, TypeAttr::get(type));
state.addTypes(dstType);
return success();
}
static void print(spirv::CooperativeMatrixLengthNVOp coopMatrix,
OpAsmPrinter &printer) {
printer << coopMatrix.getOperationName() << " : " << coopMatrix.type();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// spv.CooperativeMatrixMulAddNV // spv.CooperativeMatrixMulAddNV
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static ParseResult parseCooperativeMatrixMulAddNVOp(OpAsmParser &parser,
OperationState &state) {
SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<Type, 3> types(3);
if (parser.parseOperandList(ops, 3) || parser.parseColon() ||
parser.parseType(types[0]) || parser.parseComma() ||
parser.parseType(types[1]) || parser.parseArrow() ||
parser.parseType(types[2]) ||
parser.resolveOperands(ops, types, parser.getNameLoc(), state.operands)) {
return failure();
}
state.addTypes(types[2]);
return success();
}
static void print(spirv::CooperativeMatrixMulAddNVOp coopMatrix,
OpAsmPrinter &printer) {
printer << coopMatrix.getOperationName() << ' ' << coopMatrix.getOperand(0)
<< ", " << coopMatrix.getOperand(1) << ", "
<< coopMatrix.getOperand(2) << ", "
<< " : " << coopMatrix.getOperand(0).getType() << ", "
<< coopMatrix.getOperand(1).getType() << " -> "
<< coopMatrix.getOperand(2).getType();
}
static LogicalResult static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) { verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
if (op.c().getType() != op.result().getType()) if (op.c().getType() != op.result().getType())

View File

@ -38,7 +38,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_N
// CHECK-LABEL: @cooperative_matrix_muladd // CHECK-LABEL: @cooperative_matrix_muladd
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" { spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}, : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
spv.Return spv.Return
} }

View File

@ -38,7 +38,7 @@ spv.func @cooperative_matrix_length() -> i32 "None" {
// CHECK-LABEL: @cooperative_matrix_muladd // CHECK-LABEL: @cooperative_matrix_muladd
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" { spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}, : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup> %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
spv.Return spv.Return
} }