diff --git a/mlir/include/mlir/Dialect/SPIRV/ParserUtils.h b/mlir/include/mlir/Dialect/SPIRV/ParserUtils.h new file mode 100644 index 000000000000..f368aec45efb --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/ParserUtils.h @@ -0,0 +1,41 @@ +//===------------ ParserUtils.h - Parse text to SPIR-V ops ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines utilities used for parsing types and ops for SPIR-V +// dialect. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_SPIRV_PARSERUTILS_H_ +#define MLIR_DIALECT_SPIRV_PARSERUTILS_H_ + +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +namespace mlir { +/// Parses the next keyword in `parser` as an enumerant of the given +/// `EnumClass`. +template +static ParseResult +parseEnumKeywordAttr(EnumClass &value, ParserType &parser, + StringRef attrName = spirv::attributeName()) { + StringRef keyword; + SmallVector attr; + auto loc = parser.getCurrentLocation(); + if (parser.parseKeyword(&keyword)) + return failure(); + if (Optional attr = spirv::symbolizeEnum(keyword)) { + value = attr.getValue(); + return success(); + } + return parser.emitError(loc, "invalid ") + << attrName << " attribute specification: " << keyword; +} +} // namespace mlir + +#endif // MLIR_DIALECT_SPIRV_PARSERUTILS_H_ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 64063cb77d01..b958a10c5952 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -2991,8 +2991,10 @@ class SignlessOrUnsignedIntOfWidths widths> : AnyTypeOf), StrJoinInt.result # "-bit signless/unsigned integer">; -def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">; def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">; +def SPV_IsCooperativeMatrixType : + CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">; +def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">; def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">; def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">; @@ -3012,6 +3014,9 @@ def SPV_AnyPtr : DialectType; def SPV_AnyArray : DialectType; +def SPV_AnyCooperativeMatrix : DialectType; def SPV_AnyRTArray : DialectType; def SPV_AnyStruct : DialectType; def SPV_OC_OpGroupNonUniformFMax : I32EnumAttrCase<"OpGroupNonUniformFMax", 358>; def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>; +def SPV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>; +def SPV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>; def SPV_OpcodeAttr : SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ @@ -3271,7 +3278,8 @@ def SPV_OpcodeAttr : SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax, - SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR + SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR, + SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td new file mode 100644 index 000000000000..931f56f58755 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td @@ -0,0 +1,94 @@ +//===- SPIRVCooperativeMatrixOps.td - cooperative matmul ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the op definition spec of cooperative matrix multiply extension ops. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_COOPERATIVE_MATRIX_OPS +#define SPIRV_COOPERATIVE_MATRIX_OPS + +// ----- + +def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> { + let summary = "See extension SPV_NV_cooperative_matrix"; + + let description = [{ + Load a cooperative matrix through a pointer. + + Result Type is the type of the loaded object. It must be a cooperative + matrix type. + + Pointer is a pointer into an array. Its type must be an OpTypePointer whose + Type operand is a scalar or vector type. The storage class of Pointer must + be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is + supported) PhysicalStorageBufferEXT. + + Stride is the number of elements in the array in memory between the first + component of consecutive rows (or columns) in the result. It must be a + scalar integer type. + + ColumnMajor indicates whether the values loaded from memory are arranged in + column-major or row-major order. It must be a boolean constant instruction, + with false indicating row major and true indicating column major. + + Memory Access must be a Memory Access literal. If not present, it is the + same as specifying None. + + If ColumnMajor is false, then elements (row,*) of the result are taken in + order from contiguous locations starting at Pointer[row*Stride]. If + ColumnMajor is true, then elements (*,col) of the result are taken in order + from contiguous locations starting from Pointer[col*Stride]. Any ArrayStride + decoration on Pointer is ignored. + + For a given dynamic instance of this instruction, all operands of this + instruction must be the same for all invocations in a given scope instance + (where the scope is the scope the cooperative matrix type was created with). + All invocations in a given scope instance must be active or all must be + inactive. + + ### Custom assembly form + + ``` {.ebnf} + cooperative-matrix-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV` + storage-class ssa-use (`[` memory-access `]`)? ` + : ` cooperative-matrix-type + ``` + + For example: + + ``` + %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %colMajor + : !spv.coopmatrix + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_NV_cooperative_matrix]>, + Capability<[SPV_C_CooperativeMatrixNV]> + ]; + + let arguments = (ins + SPV_AnyPtr:$pointer, + SPV_Integer:$stride, + SPV_Bool:$columnmajor, + OptionalAttr:$memory_access + ); + + let results = (outs + SPV_AnyCooperativeMatrix:$result + ); + + let verifier = [{ return success(); }]; +} + +// ----- + +#endif // SPIRV_COOPERATIVE_MATRIX_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 518dca69873d..520ed14c9624 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -28,6 +28,7 @@ include "mlir/Dialect/SPIRV/SPIRVBitOps.td" include "mlir/Dialect/SPIRV/SPIRVCastOps.td" include "mlir/Dialect/SPIRV/SPIRVCompositeOps.td" include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td" +include "mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td" include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td" include "mlir/Dialect/SPIRV/SPIRVGroupOps.td" include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td" diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h index 3b5a82d239b9..078fb5a67225 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -54,6 +54,7 @@ SmallVector getRecursiveImpliedCapabilities(Capability cap); namespace detail { struct ArrayTypeStorage; +struct CooperativeMatrixTypeStorage; struct ImageTypeStorage; struct PointerTypeStorage; struct RuntimeArrayTypeStorage; @@ -63,6 +64,7 @@ struct StructTypeStorage; namespace TypeKind { enum Kind { Array = Type::FIRST_SPIRV_TYPE, + CooperativeMatrix, Image, Pointer, RuntimeArray, @@ -330,6 +332,34 @@ public: Optional storage = llvm::None); }; +// SPIR-V cooperative matrix type +class CooperativeMatrixNVType + : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { + return kind == TypeKind::CooperativeMatrix; + } + + static CooperativeMatrixNVType get(Type elementType, spirv::Scope scope, + unsigned rows, unsigned columns); + Type getElementType() const; + + /// Return the scope of the cooperative matrix. + spirv::Scope getScope() const; + /// return the number of rows of the matrix. + unsigned getRows() const; + /// return the number of columns of the matrix. + unsigned getColumns() const; + + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage = llvm::None); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage = llvm::None); +}; + } // end namespace spirv } // end namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index c74698a93bfc..8c4d0ebe99a7 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/ParserUtils.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/TargetAndABI.h" @@ -115,7 +116,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface { SPIRVDialect::SPIRVDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { - addTypes(); + addTypes(); addAttributes(); @@ -264,6 +266,36 @@ static Type parseArrayType(SPIRVDialect const &dialect, return ArrayType::get(elementType, count, stride); } +// cooperative-matrix-type ::= `!spv.coopmatrix` `<` element-type ',' scope ',' +// rows ',' coloumns>` +static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + if (parser.parseLess()) + return Type(); + + SmallVector dims; + llvm::SMLoc countLoc = parser.getCurrentLocation(); + if (parser.parseDimensionList(dims, /*allowDynamic=*/false)) + return Type(); + + if (dims.size() != 2) { + parser.emitError(countLoc, "expected rows and columns size."); + return Type(); + } + + auto elementTy = parseAndVerifyType(dialect, parser); + if (!elementTy) + return Type(); + + Scope scope; + if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope ")) + return Type(); + + if (parser.parseGreater()) + return Type(); + return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]); +} + // TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type // methods in alphabetical order // @@ -525,6 +557,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const { if (keyword == "array") return parseArrayType(*this, parser); + if (keyword == "coopmatrix") + return parseCooperativeMatrixType(*this, parser); if (keyword == "image") return parseImageType(*this, parser); if (keyword == "ptr") @@ -595,11 +629,20 @@ static void print(StructType type, DialectAsmPrinter &os) { os << ">"; } +static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) { + os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; + os << type.getElementType() << ", " << stringifyScope(type.getScope()); + os << ">"; +} + void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { switch (type.getKind()) { case TypeKind::Array: print(type.cast(), os); return; + case TypeKind::CooperativeMatrix: + print(type.cast(), os); + return; case TypeKind::Pointer: print(type.cast(), os); return; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index e7bdfe902804..eed597b1d21c 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/ParserUtils.h" #include "mlir/Dialect/SPIRV/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" @@ -140,25 +141,6 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state, return success(); } -/// Parses the next keyword in `parser` as an enumerant of the given -/// `EnumClass`. -template -static ParseResult -parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser, - StringRef attrName = spirv::attributeName()) { - StringRef keyword; - SmallVector attr; - auto loc = parser.getCurrentLocation(); - if (parser.parseKeyword(&keyword)) - return failure(); - if (Optional attr = spirv::symbolizeEnum(keyword)) { - value = attr.getValue(); - return success(); - } - return parser.emitError(loc, "invalid ") - << attrName << " attribute specification: " << keyword; -} - /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass` /// and inserts the enumerant into `state` as an 32-bit integer attribute with /// the enum class's name as attribute name. @@ -2637,6 +2619,49 @@ static LogicalResult verify(spirv::VariableOp varOp) { return success(); } +//===----------------------------------------------------------------------===// +// spv.CooperativeMatrixLoadNV +//===----------------------------------------------------------------------===// + +static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser, + OperationState &state) { + spirv::StorageClass storageClass; + SmallVector operandInfo; + Type strideType = parser.getBuilder().getIntegerType(32); + Type columnMajorType = parser.getBuilder().getIntegerType(1); + Type elementType; + if (parseEnumStrAttr(storageClass, parser) || + parser.parseOperandList(operandInfo, 3) || + parseMemoryAccessAttributes(parser, state) || parser.parseColon() || + parser.parseType(elementType)) { + return failure(); + } + + auto ptrType = spirv::PointerType::get( + elementType.cast().getElementType(), + storageClass); + SmallVector OperandType = {ptrType, strideType, columnMajorType}; + if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(), + state.operands)) { + return failure(); + } + + state.addTypes(elementType); + return success(); +} + +static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) { + StringRef sc = stringifyStorageClass( + M.pointer().getType().cast().getStorageClass()); + printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " \"" << sc + << "\" " << M.pointer() << ", " << M.stride() << ", " + << M.columnmajor(); + // Print optional memory access attribute. + if (auto memAccess = M.memory_access()) + printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; + printer << " : " << M.getType(); +} + namespace mlir { namespace spirv { diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp index 71ca0c3d2bc7..ce5a6c0c4fd9 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -158,6 +158,7 @@ void ArrayType::getCapabilities( bool CompositeType::classof(Type type) { switch (type.getKind()) { case TypeKind::Array: + case TypeKind::CooperativeMatrix: case TypeKind::RuntimeArray: case TypeKind::Struct: return true; @@ -177,6 +178,8 @@ Type CompositeType::getElementType(unsigned index) const { switch (getKind()) { case spirv::TypeKind::Array: return cast().getElementType(); + case spirv::TypeKind::CooperativeMatrix: + return cast().getElementType(); case spirv::TypeKind::RuntimeArray: return cast().getElementType(); case spirv::TypeKind::Struct: @@ -192,6 +195,9 @@ unsigned CompositeType::getNumElements() const { switch (getKind()) { case spirv::TypeKind::Array: return cast().getNumElements(); + case spirv::TypeKind::CooperativeMatrix: + return cast().getRows() * + cast().getColumns(); case spirv::TypeKind::RuntimeArray: llvm_unreachable( "invalid to query number of elements of spirv::RuntimeArray type"); @@ -211,6 +217,9 @@ void CompositeType::getExtensions( case spirv::TypeKind::Array: cast().getExtensions(extensions, storage); break; + case spirv::TypeKind::CooperativeMatrix: + cast().getExtensions(extensions, storage); + break; case spirv::TypeKind::RuntimeArray: cast().getExtensions(extensions, storage); break; @@ -233,6 +242,9 @@ void CompositeType::getCapabilities( case spirv::TypeKind::Array: cast().getCapabilities(capabilities, storage); break; + case spirv::TypeKind::CooperativeMatrix: + cast().getCapabilities(capabilities, storage); + break; case spirv::TypeKind::RuntimeArray: cast().getCapabilities(capabilities, storage); break; @@ -248,6 +260,70 @@ void CompositeType::getCapabilities( } } +//===----------------------------------------------------------------------===// +// CooperativeMatrixType +//===----------------------------------------------------------------------===// + +struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage { + using KeyTy = std::tuple; + + static CooperativeMatrixTypeStorage * + construct(TypeStorageAllocator &allocator, const KeyTy &key) { + return new (allocator.allocate()) + CooperativeMatrixTypeStorage(key); + } + + bool operator==(const KeyTy &key) const { + return key == KeyTy(elementType, getScope(), rows, columns); + } + + CooperativeMatrixTypeStorage(const KeyTy &key) + : TypeStorage(static_cast(std::get<1>(key))), + elementType(std::get<0>(key)), rows(std::get<2>(key)), + columns(std::get<3>(key)) {} + + Scope getScope() const { return static_cast(getSubclassData()); } + + Type elementType; + unsigned rows; + unsigned columns; +}; + +CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType, + Scope scope, unsigned rows, + unsigned columns) { + return Base::get(elementType.getContext(), TypeKind::CooperativeMatrix, + elementType, scope, rows, columns); +} + +Type CooperativeMatrixNVType::getElementType() const { + return getImpl()->elementType; +} + +Scope CooperativeMatrixNVType::getScope() const { + return getImpl()->getScope(); +} + +unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; } + +unsigned CooperativeMatrixNVType::getColumns() const { + return getImpl()->columns; +} + +void CooperativeMatrixNVType::getExtensions( + SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage) { + getElementType().cast().getExtensions(extensions, storage); + extensions.push_back(Extension::SPV_NV_cooperative_matrix); +} + +void CooperativeMatrixNVType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage) { + getElementType().cast().getCapabilities(capabilities, storage); + capabilities.push_back(Capability::CooperativeMatrixNV); +} + //===----------------------------------------------------------------------===// // ImageType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index a45780ba63a0..87f233580b75 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -217,6 +217,8 @@ private: LogicalResult processArrayType(ArrayRef operands); + LogicalResult processCooperativeMatrixType(ArrayRef operands); + LogicalResult processFunctionType(ArrayRef operands); LogicalResult processRuntimeArrayType(ArrayRef operands); @@ -1160,6 +1162,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode, } break; case spirv::Opcode::OpTypeArray: return processArrayType(operands); + case spirv::Opcode::OpTypeCooperativeMatrixNV: + return processCooperativeMatrixType(operands); case spirv::Opcode::OpTypeFunction: return processFunctionType(operands); case spirv::Opcode::OpTypeRuntimeArray: @@ -1229,6 +1233,35 @@ LogicalResult Deserializer::processFunctionType(ArrayRef operands) { return success(); } +LogicalResult +Deserializer::processCooperativeMatrixType(ArrayRef operands) { + if (operands.size() != 5) { + return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element " + "type and row x column parameters"); + } + + Type elementTy = getType(operands[1]); + if (!elementTy) { + return emitError(unknownLoc, + "OpTypeCooperativeMatrix references undefined ") + << operands[1]; + } + + auto scope = spirv::symbolizeScope(operands[2]); + if (!scope) { + return emitError(unknownLoc, + "OpTypeCooperativeMatrix references undefined scope ") + << operands[2]; + } + + unsigned rows = operands[3]; + unsigned columns = operands[4]; + + typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get( + elementTy, scope.getValue(), rows, columns); + return success(); +} + LogicalResult Deserializer::processRuntimeArrayType(ArrayRef operands) { if (operands.size() != 2) { @@ -2210,6 +2243,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, case spirv::Opcode::OpTypeRuntimeArray: case spirv::Opcode::OpTypeStruct: case spirv::Opcode::OpTypePointer: + case spirv::Opcode::OpTypeCooperativeMatrixNV: return processType(opcode, operands); case spirv::Opcode::OpConstant: return processConstant(operands, /*isSpec=*/false); diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 2b500ddbf985..8ea0c4f4711b 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -1096,6 +1096,21 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID, return success(); } + if (auto cooperativeMatrixType = + type.dyn_cast()) { + uint32_t elementTypeID = 0; + if (failed(processType(loc, cooperativeMatrixType.getElementType(), + elementTypeID))) { + return failure(); + } + typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; + operands.push_back(elementTypeID); + operands.push_back(static_cast(cooperativeMatrixType.getScope())); + operands.push_back(cooperativeMatrixType.getRows()); + operands.push_back(cooperativeMatrixType.getColumns()); + return success(); + } + // TODO(ravishankarm) : Handle other types. return emitError(loc, "unhandled type in serialization: ") << type; } diff --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir new file mode 100644 index 000000000000..e90996ee24b7 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s + +spv.module Logical GLSL450 requires #spv.vce { + // CHECK-LABEL: @cooperative_matrix_load + spv.func @cooperative_matrix_load(%ptr : !spv.ptr, %stride : i32, %b : i1) "None" { + // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup> + %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup> + spv.Return + } + + // CHECK-LABEL: @cooperative_matrix_load_memaccess + spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr, %stride : i32, %b : i1) "None" { + // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> + %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return + } +} diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir new file mode 100644 index 000000000000..c121943acf82 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @cooperative_matrix_load +spv.func @cooperative_matrix_load(%ptr : !spv.ptr, %stride : i32, %b : i1) "None" { + // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup> + %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup> + spv.Return +} + +// ----- +// CHECK-LABEL: @cooperative_matrix_load_memaccess +spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr, %stride : i32, %b : i1) "None" { + // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> + %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return +} diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir index 4c1adafce4a8..697177b0b98e 100644 --- a/mlir/test/Dialect/SPIRV/types.mlir +++ b/mlir/test/Dialect/SPIRV/types.mlir @@ -327,3 +327,23 @@ func @struct_type_missing_comma(!spv.struct) // expected-error @+1 {{expected ']'}} func @struct_type_missing_comma(!spv.struct) + +// ----- + +//===----------------------------------------------------------------------===// +// CooperativeMatrix +//===----------------------------------------------------------------------===// + +// CHECK: func @coop_matrix_type(!spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<8x8xf32, Workgroup>) +func @coop_matrix_type(!spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<8x8xf32, Workgroup>) -> () + +// ----- + +// expected-error @+1 {{expected ','}} +func @missing_scope(!spv.coopmatrix<8x16xi32>) -> () + +// ----- + +// expected-error @+1 {{expected rows and columns size}} +func @missing_count(!spv.coopmatrix<8xi32, Subgroup>) -> () +