[mlir][spirv] First step to support spirv cooperative matrix extension.

Add a new type to SPIRV dialect for cooperative matrix and add new op for
cooperative matrix load. This is missing most instructions to support
cooperative matrix extension but this is a stop-gap patch to avoid creating big
review.

Differential Revision: https://reviews.llvm.org/D80043
This commit is contained in:
Thomas Raoux 2020-05-19 19:07:21 -07:00
parent 2b59e9f1bd
commit b359bbaa8b
13 changed files with 442 additions and 22 deletions

View File

@ -0,0 +1,41 @@
//===------------ ParserUtils.h - Parse text to SPIR-V ops ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines utilities used for parsing types and ops for SPIR-V
// dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPIRV_PARSERUTILS_H_
#define MLIR_DIALECT_SPIRV_PARSERUTILS_H_
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
namespace mlir {
/// Parses the next keyword in `parser` as an enumerant of the given
/// `EnumClass`.
template <typename EnumClass, typename ParserType>
static ParseResult
parseEnumKeywordAttr(EnumClass &value, ParserType &parser,
StringRef attrName = spirv::attributeName<EnumClass>()) {
StringRef keyword;
SmallVector<NamedAttribute, 1> attr;
auto loc = parser.getCurrentLocation();
if (parser.parseKeyword(&keyword))
return failure();
if (Optional<EnumClass> attr = spirv::symbolizeEnum<EnumClass>(keyword)) {
value = attr.getValue();
return success();
}
return parser.emitError(loc, "invalid ")
<< attrName << " attribute specification: " << keyword;
}
} // namespace mlir
#endif // MLIR_DIALECT_SPIRV_PARSERUTILS_H_

View File

@ -2991,8 +2991,10 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, IOrUI<w>),
StrJoinInt<widths, "/">.result # "-bit signless/unsigned integer">;
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
def SPV_IsCooperativeMatrixType :
CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">;
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
@ -3012,6 +3014,9 @@ def SPV_AnyPtr : DialectType<SPIRV_Dialect, SPV_IsPtrType,
"any SPIR-V pointer type">;
def SPV_AnyArray : DialectType<SPIRV_Dialect, SPV_IsArrayType,
"any SPIR-V array type">;
def SPV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
SPV_IsCooperativeMatrixType,
"any SPIR-V cooperative matrix type">;
def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
"any SPIR-V runtime array type">;
def SPV_AnyStruct : DialectType<SPIRV_Dialect, SPV_IsStructType,
@ -3220,6 +3225,8 @@ def SPV_OC_OpGroupNonUniformSMax : I32EnumAttrCase<"OpGroupNonUniformSMax"
def SPV_OC_OpGroupNonUniformUMax : I32EnumAttrCase<"OpGroupNonUniformUMax", 357>;
def SPV_OC_OpGroupNonUniformFMax : I32EnumAttrCase<"OpGroupNonUniformFMax", 358>;
def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
def SPV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
def SPV_OpcodeAttr :
SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@ -3271,7 +3278,8 @@ def SPV_OpcodeAttr :
SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin,
SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax,
SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR
SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

View File

@ -0,0 +1,94 @@
//===- SPIRVCooperativeMatrixOps.td - cooperative matmul ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the op definition spec of cooperative matrix multiply extension ops.
//
//===----------------------------------------------------------------------===//
#ifndef SPIRV_COOPERATIVE_MATRIX_OPS
#define SPIRV_COOPERATIVE_MATRIX_OPS
// -----
def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
let summary = "See extension SPV_NV_cooperative_matrix";
let description = [{
Load a cooperative matrix through a pointer.
Result Type is the type of the loaded object. It must be a cooperative
matrix type.
Pointer is a pointer into an array. Its type must be an OpTypePointer whose
Type operand is a scalar or vector type. The storage class of Pointer must
be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
supported) PhysicalStorageBufferEXT.
Stride is the number of elements in the array in memory between the first
component of consecutive rows (or columns) in the result. It must be a
scalar integer type.
ColumnMajor indicates whether the values loaded from memory are arranged in
column-major or row-major order. It must be a boolean constant instruction,
with false indicating row major and true indicating column major.
Memory Access must be a Memory Access literal. If not present, it is the
same as specifying None.
If ColumnMajor is false, then elements (row,*) of the result are taken in
order from contiguous locations starting at Pointer[row*Stride]. If
ColumnMajor is true, then elements (*,col) of the result are taken in order
from contiguous locations starting from Pointer[col*Stride]. Any ArrayStride
decoration on Pointer is ignored.
For a given dynamic instance of this instruction, all operands of this
instruction must be the same for all invocations in a given scope instance
(where the scope is the scope the cooperative matrix type was created with).
All invocations in a given scope instance must be active or all must be
inactive.
### Custom assembly form
``` {.ebnf}
cooperative-matrix-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV`
storage-class ssa-use (`[` memory-access `]`)? `
: ` cooperative-matrix-type
```
For example:
```
%0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %colMajor
: !spv.coopmatrix<i32, Workgroup, 16, 8>
```
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[SPV_NV_cooperative_matrix]>,
Capability<[SPV_C_CooperativeMatrixNV]>
];
let arguments = (ins
SPV_AnyPtr:$pointer,
SPV_Integer:$stride,
SPV_Bool:$columnmajor,
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access
);
let results = (outs
SPV_AnyCooperativeMatrix:$result
);
let verifier = [{ return success(); }];
}
// -----
#endif // SPIRV_COOPERATIVE_MATRIX_OPS

View File

@ -28,6 +28,7 @@ include "mlir/Dialect/SPIRV/SPIRVBitOps.td"
include "mlir/Dialect/SPIRV/SPIRVCastOps.td"
include "mlir/Dialect/SPIRV/SPIRVCompositeOps.td"
include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td"
include "mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td"
include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"

View File

@ -54,6 +54,7 @@ SmallVector<Capability, 0> getRecursiveImpliedCapabilities(Capability cap);
namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
@ -63,6 +64,7 @@ struct StructTypeStorage;
namespace TypeKind {
enum Kind {
Array = Type::FIRST_SPIRV_TYPE,
CooperativeMatrix,
Image,
Pointer,
RuntimeArray,
@ -330,6 +332,34 @@ public:
Optional<spirv::StorageClass> storage = llvm::None);
};
// SPIR-V cooperative matrix type
class CooperativeMatrixNVType
: public Type::TypeBase<CooperativeMatrixNVType, SPIRVType,
detail::CooperativeMatrixTypeStorage> {
public:
using Base::Base;
static bool kindof(unsigned kind) {
return kind == TypeKind::CooperativeMatrix;
}
static CooperativeMatrixNVType get(Type elementType, spirv::Scope scope,
unsigned rows, unsigned columns);
Type getElementType() const;
/// Return the scope of the cooperative matrix.
spirv::Scope getScope() const;
/// return the number of rows of the matrix.
unsigned getRows() const;
/// return the number of columns of the matrix.
unsigned getColumns() const;
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<spirv::StorageClass> storage = llvm::None);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<spirv::StorageClass> storage = llvm::None);
};
} // end namespace spirv
} // end namespace mlir

View File

@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/ParserUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
@ -115,7 +116,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
SPIRVDialect::SPIRVDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, PointerType,
RuntimeArrayType, StructType>();
addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
@ -264,6 +266,36 @@ static Type parseArrayType(SPIRVDialect const &dialect,
return ArrayType::get(elementType, count, stride);
}
// cooperative-matrix-type ::= `!spv.coopmatrix` `<` element-type ',' scope ','
// rows ',' coloumns>`
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
return Type();
SmallVector<int64_t, 2> dims;
llvm::SMLoc countLoc = parser.getCurrentLocation();
if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
return Type();
if (dims.size() != 2) {
parser.emitError(countLoc, "expected rows and columns size.");
return Type();
}
auto elementTy = parseAndVerifyType(dialect, parser);
if (!elementTy)
return Type();
Scope scope;
if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
return Type();
if (parser.parseGreater())
return Type();
return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
}
// TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
// methods in alphabetical order
//
@ -525,6 +557,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
if (keyword == "array")
return parseArrayType(*this, parser);
if (keyword == "coopmatrix")
return parseCooperativeMatrixType(*this, parser);
if (keyword == "image")
return parseImageType(*this, parser);
if (keyword == "ptr")
@ -595,11 +629,20 @@ static void print(StructType type, DialectAsmPrinter &os) {
os << ">";
}
static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
os << type.getElementType() << ", " << stringifyScope(type.getScope());
os << ">";
}
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
switch (type.getKind()) {
case TypeKind::Array:
print(type.cast<ArrayType>(), os);
return;
case TypeKind::CooperativeMatrix:
print(type.cast<CooperativeMatrixNVType>(), os);
return;
case TypeKind::Pointer:
print(type.cast<PointerType>(), os);
return;

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/ParserUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
@ -140,25 +141,6 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
return success();
}
/// Parses the next keyword in `parser` as an enumerant of the given
/// `EnumClass`.
template <typename EnumClass>
static ParseResult
parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
StringRef attrName = spirv::attributeName<EnumClass>()) {
StringRef keyword;
SmallVector<NamedAttribute, 1> attr;
auto loc = parser.getCurrentLocation();
if (parser.parseKeyword(&keyword))
return failure();
if (Optional<EnumClass> attr = spirv::symbolizeEnum<EnumClass>(keyword)) {
value = attr.getValue();
return success();
}
return parser.emitError(loc, "invalid ")
<< attrName << " attribute specification: " << keyword;
}
/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
/// and inserts the enumerant into `state` as an 32-bit integer attribute with
/// the enum class's name as attribute name.
@ -2637,6 +2619,49 @@ static LogicalResult verify(spirv::VariableOp varOp) {
return success();
}
//===----------------------------------------------------------------------===//
// spv.CooperativeMatrixLoadNV
//===----------------------------------------------------------------------===//
static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
OperationState &state) {
spirv::StorageClass storageClass;
SmallVector<OpAsmParser::OperandType, 3> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32);
Type columnMajorType = parser.getBuilder().getIntegerType(1);
Type elementType;
if (parseEnumStrAttr(storageClass, parser) ||
parser.parseOperandList(operandInfo, 3) ||
parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
parser.parseType(elementType)) {
return failure();
}
auto ptrType = spirv::PointerType::get(
elementType.cast<spirv::CooperativeMatrixNVType>().getElementType(),
storageClass);
SmallVector<Type, 3> OperandType = {ptrType, strideType, columnMajorType};
if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(),
state.operands)) {
return failure();
}
state.addTypes(elementType);
return success();
}
static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) {
StringRef sc = stringifyStorageClass(
M.pointer().getType().cast<spirv::PointerType>().getStorageClass());
printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " \"" << sc
<< "\" " << M.pointer() << ", " << M.stride() << ", "
<< M.columnmajor();
// Print optional memory access attribute.
if (auto memAccess = M.memory_access())
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
printer << " : " << M.getType();
}
namespace mlir {
namespace spirv {

View File

@ -158,6 +158,7 @@ void ArrayType::getCapabilities(
bool CompositeType::classof(Type type) {
switch (type.getKind()) {
case TypeKind::Array:
case TypeKind::CooperativeMatrix:
case TypeKind::RuntimeArray:
case TypeKind::Struct:
return true;
@ -177,6 +178,8 @@ Type CompositeType::getElementType(unsigned index) const {
switch (getKind()) {
case spirv::TypeKind::Array:
return cast<ArrayType>().getElementType();
case spirv::TypeKind::CooperativeMatrix:
return cast<CooperativeMatrixNVType>().getElementType();
case spirv::TypeKind::RuntimeArray:
return cast<RuntimeArrayType>().getElementType();
case spirv::TypeKind::Struct:
@ -192,6 +195,9 @@ unsigned CompositeType::getNumElements() const {
switch (getKind()) {
case spirv::TypeKind::Array:
return cast<ArrayType>().getNumElements();
case spirv::TypeKind::CooperativeMatrix:
return cast<CooperativeMatrixNVType>().getRows() *
cast<CooperativeMatrixNVType>().getColumns();
case spirv::TypeKind::RuntimeArray:
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
@ -211,6 +217,9 @@ void CompositeType::getExtensions(
case spirv::TypeKind::Array:
cast<ArrayType>().getExtensions(extensions, storage);
break;
case spirv::TypeKind::CooperativeMatrix:
cast<CooperativeMatrixNVType>().getExtensions(extensions, storage);
break;
case spirv::TypeKind::RuntimeArray:
cast<RuntimeArrayType>().getExtensions(extensions, storage);
break;
@ -233,6 +242,9 @@ void CompositeType::getCapabilities(
case spirv::TypeKind::Array:
cast<ArrayType>().getCapabilities(capabilities, storage);
break;
case spirv::TypeKind::CooperativeMatrix:
cast<CooperativeMatrixNVType>().getCapabilities(capabilities, storage);
break;
case spirv::TypeKind::RuntimeArray:
cast<RuntimeArrayType>().getCapabilities(capabilities, storage);
break;
@ -248,6 +260,70 @@ void CompositeType::getCapabilities(
}
}
//===----------------------------------------------------------------------===//
// CooperativeMatrixType
//===----------------------------------------------------------------------===//
struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
static CooperativeMatrixTypeStorage *
construct(TypeStorageAllocator &allocator, const KeyTy &key) {
return new (allocator.allocate<CooperativeMatrixTypeStorage>())
CooperativeMatrixTypeStorage(key);
}
bool operator==(const KeyTy &key) const {
return key == KeyTy(elementType, getScope(), rows, columns);
}
CooperativeMatrixTypeStorage(const KeyTy &key)
: TypeStorage(static_cast<unsigned>(std::get<1>(key))),
elementType(std::get<0>(key)), rows(std::get<2>(key)),
columns(std::get<3>(key)) {}
Scope getScope() const { return static_cast<Scope>(getSubclassData()); }
Type elementType;
unsigned rows;
unsigned columns;
};
CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
Scope scope, unsigned rows,
unsigned columns) {
return Base::get(elementType.getContext(), TypeKind::CooperativeMatrix,
elementType, scope, rows, columns);
}
Type CooperativeMatrixNVType::getElementType() const {
return getImpl()->elementType;
}
Scope CooperativeMatrixNVType::getScope() const {
return getImpl()->getScope();
}
unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
unsigned CooperativeMatrixNVType::getColumns() const {
return getImpl()->columns;
}
void CooperativeMatrixNVType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage) {
getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
extensions.push_back(Extension::SPV_NV_cooperative_matrix);
}
void CooperativeMatrixNVType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage) {
getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
capabilities.push_back(Capability::CooperativeMatrixNV);
}
//===----------------------------------------------------------------------===//
// ImageType
//===----------------------------------------------------------------------===//

View File

@ -217,6 +217,8 @@ private:
LogicalResult processArrayType(ArrayRef<uint32_t> operands);
LogicalResult processCooperativeMatrixType(ArrayRef<uint32_t> operands);
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
@ -1160,6 +1162,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
} break;
case spirv::Opcode::OpTypeArray:
return processArrayType(operands);
case spirv::Opcode::OpTypeCooperativeMatrixNV:
return processCooperativeMatrixType(operands);
case spirv::Opcode::OpTypeFunction:
return processFunctionType(operands);
case spirv::Opcode::OpTypeRuntimeArray:
@ -1229,6 +1233,35 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
return success();
}
LogicalResult
Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
if (operands.size() != 5) {
return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
"type and row x column parameters");
}
Type elementTy = getType(operands[1]);
if (!elementTy) {
return emitError(unknownLoc,
"OpTypeCooperativeMatrix references undefined <id> ")
<< operands[1];
}
auto scope = spirv::symbolizeScope(operands[2]);
if (!scope) {
return emitError(unknownLoc,
"OpTypeCooperativeMatrix references undefined scope <id> ")
<< operands[2];
}
unsigned rows = operands[3];
unsigned columns = operands[4];
typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
elementTy, scope.getValue(), rows, columns);
return success();
}
LogicalResult
Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
@ -2210,6 +2243,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
case spirv::Opcode::OpTypeRuntimeArray:
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
case spirv::Opcode::OpTypeCooperativeMatrixNV:
return processType(opcode, operands);
case spirv::Opcode::OpConstant:
return processConstant(operands, /*isSpec=*/false);

View File

@ -1096,6 +1096,21 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
return success();
}
if (auto cooperativeMatrixType =
type.dyn_cast<spirv::CooperativeMatrixNVType>()) {
uint32_t elementTypeID = 0;
if (failed(processType(loc, cooperativeMatrixType.getElementType(),
elementTypeID))) {
return failure();
}
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
operands.push_back(elementTypeID);
operands.push_back(static_cast<uint32_t>(cooperativeMatrixType.getScope()));
operands.push_back(cooperativeMatrixType.getRows());
operands.push_back(cooperativeMatrixType.getColumns());
return success();
}
// TODO(ravishankarm) : Handle other types.
return emitError(loc, "unhandled type in serialization: ") << type;
}

View File

@ -0,0 +1,17 @@
// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> {
// CHECK-LABEL: @cooperative_matrix_load
spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
%0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
spv.Return
}
// CHECK-LABEL: @cooperative_matrix_load_memaccess
spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}
}

View File

@ -0,0 +1,16 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
// CHECK-LABEL: @cooperative_matrix_load
spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
%0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
spv.Return
}
// -----
// CHECK-LABEL: @cooperative_matrix_load_memaccess
spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}

View File

@ -327,3 +327,23 @@ func @struct_type_missing_comma(!spv.struct<f32 [0 NonWritable], i32 [4]>)
// expected-error @+1 {{expected ']'}}
func @struct_type_missing_comma(!spv.struct<f32 [0, NonWritable NonReadable], i32 [4]>)
// -----
//===----------------------------------------------------------------------===//
// CooperativeMatrix
//===----------------------------------------------------------------------===//
// CHECK: func @coop_matrix_type(!spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<8x8xf32, Workgroup>)
func @coop_matrix_type(!spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<8x8xf32, Workgroup>) -> ()
// -----
// expected-error @+1 {{expected ','}}
func @missing_scope(!spv.coopmatrix<8x16xi32>) -> ()
// -----
// expected-error @+1 {{expected rows and columns size}}
func @missing_count(!spv.coopmatrix<8xi32, Subgroup>) -> ()