forked from OSchip/llvm-project
[mlir][spirv] Clean up coop matrix assembly declaration.
Address code review feedback and use declarative assembly format. Differential Revision: https://reviews.llvm.org/D80687
This commit is contained in:
parent
7265ff928a
commit
c652c306a6
|
@ -39,6 +39,8 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
|
|||
```
|
||||
}];
|
||||
|
||||
let assemblyFormat = "attr-dict `:` $type";
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_5>,
|
||||
|
@ -139,7 +141,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
|
|||
// -----
|
||||
|
||||
def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
|
||||
let summary = "See extension SPV_NV_cooperative_matrix";
|
||||
|
||||
let description = [{
|
||||
|
@ -188,6 +190,10 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
|
|||
```
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
operands attr-dict`:` type($a) `,` type($b) `->` type($c)
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_5>,
|
||||
|
|
|
@ -1134,12 +1134,11 @@ static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
|
|||
return compositeConstructOp.emitError(
|
||||
"has incorrect number of operands: expected ")
|
||||
<< "1, but provided " << constituents.size();
|
||||
} else {
|
||||
if (constituents.size() != cType.getNumElements())
|
||||
return compositeConstructOp.emitError(
|
||||
"has incorrect number of operands: expected ")
|
||||
<< cType.getNumElements() << ", but provided "
|
||||
<< constituents.size();
|
||||
} else if (constituents.size() != cType.getNumElements()) {
|
||||
return compositeConstructOp.emitError(
|
||||
"has incorrect number of operands: expected ")
|
||||
<< cType.getNumElements() << ", but provided "
|
||||
<< constituents.size();
|
||||
}
|
||||
|
||||
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
|
||||
|
@ -2735,57 +2734,10 @@ static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
|
|||
printer << " : " << coopMatrix.getOperand(1).getType();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.CooperativeMatrixLengthNV
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseCooperativeMatrixLengthNVOp(OpAsmParser &parser,
|
||||
OperationState &state) {
|
||||
OpAsmParser::OperandType operandInfo;
|
||||
Type dstType = parser.getBuilder().getIntegerType(32);
|
||||
Type type;
|
||||
if (parser.parseColonType(type)) {
|
||||
return failure();
|
||||
}
|
||||
state.addAttribute(kTypeAttrName, TypeAttr::get(type));
|
||||
state.addTypes(dstType);
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::CooperativeMatrixLengthNVOp coopMatrix,
|
||||
OpAsmPrinter &printer) {
|
||||
printer << coopMatrix.getOperationName() << " : " << coopMatrix.type();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.CooperativeMatrixMulAddNV
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseCooperativeMatrixMulAddNVOp(OpAsmParser &parser,
|
||||
OperationState &state) {
|
||||
SmallVector<OpAsmParser::OperandType, 3> ops;
|
||||
SmallVector<Type, 3> types(3);
|
||||
if (parser.parseOperandList(ops, 3) || parser.parseColon() ||
|
||||
parser.parseType(types[0]) || parser.parseComma() ||
|
||||
parser.parseType(types[1]) || parser.parseArrow() ||
|
||||
parser.parseType(types[2]) ||
|
||||
parser.resolveOperands(ops, types, parser.getNameLoc(), state.operands)) {
|
||||
return failure();
|
||||
}
|
||||
state.addTypes(types[2]);
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::CooperativeMatrixMulAddNVOp coopMatrix,
|
||||
OpAsmPrinter &printer) {
|
||||
printer << coopMatrix.getOperationName() << ' ' << coopMatrix.getOperand(0)
|
||||
<< ", " << coopMatrix.getOperand(1) << ", "
|
||||
<< coopMatrix.getOperand(2) << ", "
|
||||
<< " : " << coopMatrix.getOperand(0).getType() << ", "
|
||||
<< coopMatrix.getOperand(1).getType() << " -> "
|
||||
<< coopMatrix.getOperand(2).getType();
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
|
||||
if (op.c().getType() != op.result().getType())
|
||||
|
|
|
@ -38,7 +38,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_N
|
|||
|
||||
// CHECK-LABEL: @cooperative_matrix_muladd
|
||||
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}, : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
||||
// CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
||||
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ spv.func @cooperative_matrix_length() -> i32 "None" {
|
|||
|
||||
// CHECK-LABEL: @cooperative_matrix_muladd
|
||||
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
|
||||
// CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}, : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
||||
// CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
||||
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue