forked from OSchip/llvm-project
[mlir][spirv] Clean up coop matrix assembly declaration.
Address code review feedback and use declarative assembly format. Differential Revision: https://reviews.llvm.org/D80687
This commit is contained in:
parent
7265ff928a
commit
c652c306a6
|
@ -39,6 +39,8 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = "attr-dict `:` $type";
|
||||||
|
|
||||||
let availability = [
|
let availability = [
|
||||||
MinVersion<SPV_V_1_0>,
|
MinVersion<SPV_V_1_0>,
|
||||||
MaxVersion<SPV_V_1_5>,
|
MaxVersion<SPV_V_1_5>,
|
||||||
|
@ -139,7 +141,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
|
def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
|
||||||
let summary = "See extension SPV_NV_cooperative_matrix";
|
let summary = "See extension SPV_NV_cooperative_matrix";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -188,6 +190,10 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
operands attr-dict`:` type($a) `,` type($b) `->` type($c)
|
||||||
|
}];
|
||||||
|
|
||||||
let availability = [
|
let availability = [
|
||||||
MinVersion<SPV_V_1_0>,
|
MinVersion<SPV_V_1_0>,
|
||||||
MaxVersion<SPV_V_1_5>,
|
MaxVersion<SPV_V_1_5>,
|
||||||
|
|
|
@ -1134,12 +1134,11 @@ static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
|
||||||
return compositeConstructOp.emitError(
|
return compositeConstructOp.emitError(
|
||||||
"has incorrect number of operands: expected ")
|
"has incorrect number of operands: expected ")
|
||||||
<< "1, but provided " << constituents.size();
|
<< "1, but provided " << constituents.size();
|
||||||
} else {
|
} else if (constituents.size() != cType.getNumElements()) {
|
||||||
if (constituents.size() != cType.getNumElements())
|
return compositeConstructOp.emitError(
|
||||||
return compositeConstructOp.emitError(
|
"has incorrect number of operands: expected ")
|
||||||
"has incorrect number of operands: expected ")
|
<< cType.getNumElements() << ", but provided "
|
||||||
<< cType.getNumElements() << ", but provided "
|
<< constituents.size();
|
||||||
<< constituents.size();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
|
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
|
||||||
|
@ -2735,57 +2734,10 @@ static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
|
||||||
printer << " : " << coopMatrix.getOperand(1).getType();
|
printer << " : " << coopMatrix.getOperand(1).getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// spv.CooperativeMatrixLengthNV
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
static ParseResult parseCooperativeMatrixLengthNVOp(OpAsmParser &parser,
|
|
||||||
OperationState &state) {
|
|
||||||
OpAsmParser::OperandType operandInfo;
|
|
||||||
Type dstType = parser.getBuilder().getIntegerType(32);
|
|
||||||
Type type;
|
|
||||||
if (parser.parseColonType(type)) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
state.addAttribute(kTypeAttrName, TypeAttr::get(type));
|
|
||||||
state.addTypes(dstType);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print(spirv::CooperativeMatrixLengthNVOp coopMatrix,
|
|
||||||
OpAsmPrinter &printer) {
|
|
||||||
printer << coopMatrix.getOperationName() << " : " << coopMatrix.type();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// spv.CooperativeMatrixMulAddNV
|
// spv.CooperativeMatrixMulAddNV
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static ParseResult parseCooperativeMatrixMulAddNVOp(OpAsmParser &parser,
|
|
||||||
OperationState &state) {
|
|
||||||
SmallVector<OpAsmParser::OperandType, 3> ops;
|
|
||||||
SmallVector<Type, 3> types(3);
|
|
||||||
if (parser.parseOperandList(ops, 3) || parser.parseColon() ||
|
|
||||||
parser.parseType(types[0]) || parser.parseComma() ||
|
|
||||||
parser.parseType(types[1]) || parser.parseArrow() ||
|
|
||||||
parser.parseType(types[2]) ||
|
|
||||||
parser.resolveOperands(ops, types, parser.getNameLoc(), state.operands)) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
state.addTypes(types[2]);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print(spirv::CooperativeMatrixMulAddNVOp coopMatrix,
|
|
||||||
OpAsmPrinter &printer) {
|
|
||||||
printer << coopMatrix.getOperationName() << ' ' << coopMatrix.getOperand(0)
|
|
||||||
<< ", " << coopMatrix.getOperand(1) << ", "
|
|
||||||
<< coopMatrix.getOperand(2) << ", "
|
|
||||||
<< " : " << coopMatrix.getOperand(0).getType() << ", "
|
|
||||||
<< coopMatrix.getOperand(1).getType() << " -> "
|
|
||||||
<< coopMatrix.getOperand(2).getType();
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult
|
static LogicalResult
|
||||||
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
|
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
|
||||||
if (op.c().getType() != op.result().getType())
|
if (op.c().getType() != op.result().getType())
|
||||||
|
|
|
@ -38,7 +38,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_N
|
||||||
|
|
||||||
// CHECK-LABEL: @cooperative_matrix_muladd
|
// CHECK-LABEL: @cooperative_matrix_muladd
|
||||||
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
|
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
|
||||||
// CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}, : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
// CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
||||||
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
||||||
spv.Return
|
spv.Return
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ spv.func @cooperative_matrix_length() -> i32 "None" {
|
||||||
|
|
||||||
// CHECK-LABEL: @cooperative_matrix_muladd
|
// CHECK-LABEL: @cooperative_matrix_muladd
|
||||||
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
|
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
|
||||||
// CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}, : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
// CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
||||||
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
|
||||||
spv.Return
|
spv.Return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue