forked from OSchip/llvm-project
[mlir][spirv] First step to support spirv cooperative matrix extension.
Add a new type to SPIRV dialect for cooperative matrix and add new op for cooperative matrix load. This is missing most instructions to support cooperative matrix extension but this is a stop-gap patch to avoid creating big review. Differential Revision: https://reviews.llvm.org/D80043
This commit is contained in:
parent
2b59e9f1bd
commit
b359bbaa8b
|
@ -0,0 +1,41 @@
|
|||
//===------------ ParserUtils.h - Parse text to SPIR-V ops ----------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines utilities used for parsing types and ops for SPIR-V
|
||||
// dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifndef MLIR_DIALECT_SPIRV_PARSERUTILS_H_
|
||||
#define MLIR_DIALECT_SPIRV_PARSERUTILS_H_
|
||||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
|
||||
namespace mlir {
|
||||
/// Parses the next keyword in `parser` as an enumerant of the given
|
||||
/// `EnumClass`.
|
||||
template <typename EnumClass, typename ParserType>
|
||||
static ParseResult
|
||||
parseEnumKeywordAttr(EnumClass &value, ParserType &parser,
|
||||
StringRef attrName = spirv::attributeName<EnumClass>()) {
|
||||
StringRef keyword;
|
||||
SmallVector<NamedAttribute, 1> attr;
|
||||
auto loc = parser.getCurrentLocation();
|
||||
if (parser.parseKeyword(&keyword))
|
||||
return failure();
|
||||
if (Optional<EnumClass> attr = spirv::symbolizeEnum<EnumClass>(keyword)) {
|
||||
value = attr.getValue();
|
||||
return success();
|
||||
}
|
||||
return parser.emitError(loc, "invalid ")
|
||||
<< attrName << " attribute specification: " << keyword;
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_SPIRV_PARSERUTILS_H_
|
|
@ -2991,8 +2991,10 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
|
|||
AnyTypeOf<!foreach(w, widths, IOrUI<w>),
|
||||
StrJoinInt<widths, "/">.result # "-bit signless/unsigned integer">;
|
||||
|
||||
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
|
||||
def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
|
||||
def SPV_IsCooperativeMatrixType :
|
||||
CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">;
|
||||
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
|
||||
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
|
||||
def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
|
||||
|
||||
|
@ -3012,6 +3014,9 @@ def SPV_AnyPtr : DialectType<SPIRV_Dialect, SPV_IsPtrType,
|
|||
"any SPIR-V pointer type">;
|
||||
def SPV_AnyArray : DialectType<SPIRV_Dialect, SPV_IsArrayType,
|
||||
"any SPIR-V array type">;
|
||||
def SPV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
|
||||
SPV_IsCooperativeMatrixType,
|
||||
"any SPIR-V cooperative matrix type">;
|
||||
def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
|
||||
"any SPIR-V runtime array type">;
|
||||
def SPV_AnyStruct : DialectType<SPIRV_Dialect, SPV_IsStructType,
|
||||
|
@ -3220,6 +3225,8 @@ def SPV_OC_OpGroupNonUniformSMax : I32EnumAttrCase<"OpGroupNonUniformSMax"
|
|||
def SPV_OC_OpGroupNonUniformUMax : I32EnumAttrCase<"OpGroupNonUniformUMax", 357>;
|
||||
def SPV_OC_OpGroupNonUniformFMax : I32EnumAttrCase<"OpGroupNonUniformFMax", 358>;
|
||||
def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
|
||||
def SPV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
|
||||
def SPV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
|
||||
|
||||
def SPV_OpcodeAttr :
|
||||
SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
|
||||
|
@ -3271,7 +3278,8 @@ def SPV_OpcodeAttr :
|
|||
SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
|
||||
SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin,
|
||||
SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax,
|
||||
SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR
|
||||
SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
|
||||
SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV
|
||||
]>;
|
||||
|
||||
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
//===- SPIRVCooperativeMatrixOps.td - cooperative matmul ---*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the op definition spec of cooperative matrix multiply extension ops.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef SPIRV_COOPERATIVE_MATRIX_OPS
|
||||
#define SPIRV_COOPERATIVE_MATRIX_OPS
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
|
||||
let summary = "See extension SPV_NV_cooperative_matrix";
|
||||
|
||||
let description = [{
|
||||
Load a cooperative matrix through a pointer.
|
||||
|
||||
Result Type is the type of the loaded object. It must be a cooperative
|
||||
matrix type.
|
||||
|
||||
Pointer is a pointer into an array. Its type must be an OpTypePointer whose
|
||||
Type operand is a scalar or vector type. The storage class of Pointer must
|
||||
be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
|
||||
supported) PhysicalStorageBufferEXT.
|
||||
|
||||
Stride is the number of elements in the array in memory between the first
|
||||
component of consecutive rows (or columns) in the result. It must be a
|
||||
scalar integer type.
|
||||
|
||||
ColumnMajor indicates whether the values loaded from memory are arranged in
|
||||
column-major or row-major order. It must be a boolean constant instruction,
|
||||
with false indicating row major and true indicating column major.
|
||||
|
||||
Memory Access must be a Memory Access literal. If not present, it is the
|
||||
same as specifying None.
|
||||
|
||||
If ColumnMajor is false, then elements (row,*) of the result are taken in
|
||||
order from contiguous locations starting at Pointer[row*Stride]. If
|
||||
ColumnMajor is true, then elements (*,col) of the result are taken in order
|
||||
from contiguous locations starting from Pointer[col*Stride]. Any ArrayStride
|
||||
decoration on Pointer is ignored.
|
||||
|
||||
For a given dynamic instance of this instruction, all operands of this
|
||||
instruction must be the same for all invocations in a given scope instance
|
||||
(where the scope is the scope the cooperative matrix type was created with).
|
||||
All invocations in a given scope instance must be active or all must be
|
||||
inactive.
|
||||
|
||||
### Custom assembly form
|
||||
|
||||
``` {.ebnf}
|
||||
cooperative-matrix-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV`
|
||||
storage-class ssa-use (`[` memory-access `]`)? `
|
||||
: ` cooperative-matrix-type
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
%0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %colMajor
|
||||
: !spv.coopmatrix<i32, Workgroup, 16, 8>
|
||||
```
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_5>,
|
||||
Extension<[SPV_NV_cooperative_matrix]>,
|
||||
Capability<[SPV_C_CooperativeMatrixNV]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_AnyPtr:$pointer,
|
||||
SPV_Integer:$stride,
|
||||
SPV_Bool:$columnmajor,
|
||||
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_AnyCooperativeMatrix:$result
|
||||
);
|
||||
|
||||
let verifier = [{ return success(); }];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#endif // SPIRV_COOPERATIVE_MATRIX_OPS
|
|
@ -28,6 +28,7 @@ include "mlir/Dialect/SPIRV/SPIRVBitOps.td"
|
|||
include "mlir/Dialect/SPIRV/SPIRVCastOps.td"
|
||||
include "mlir/Dialect/SPIRV/SPIRVCompositeOps.td"
|
||||
include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td"
|
||||
include "mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td"
|
||||
include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
|
||||
include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
|
||||
include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
|
||||
|
|
|
@ -54,6 +54,7 @@ SmallVector<Capability, 0> getRecursiveImpliedCapabilities(Capability cap);
|
|||
|
||||
namespace detail {
|
||||
struct ArrayTypeStorage;
|
||||
struct CooperativeMatrixTypeStorage;
|
||||
struct ImageTypeStorage;
|
||||
struct PointerTypeStorage;
|
||||
struct RuntimeArrayTypeStorage;
|
||||
|
@ -63,6 +64,7 @@ struct StructTypeStorage;
|
|||
namespace TypeKind {
|
||||
enum Kind {
|
||||
Array = Type::FIRST_SPIRV_TYPE,
|
||||
CooperativeMatrix,
|
||||
Image,
|
||||
Pointer,
|
||||
RuntimeArray,
|
||||
|
@ -330,6 +332,34 @@ public:
|
|||
Optional<spirv::StorageClass> storage = llvm::None);
|
||||
};
|
||||
|
||||
// SPIR-V cooperative matrix type
|
||||
class CooperativeMatrixNVType
|
||||
: public Type::TypeBase<CooperativeMatrixNVType, SPIRVType,
|
||||
detail::CooperativeMatrixTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == TypeKind::CooperativeMatrix;
|
||||
}
|
||||
|
||||
static CooperativeMatrixNVType get(Type elementType, spirv::Scope scope,
|
||||
unsigned rows, unsigned columns);
|
||||
Type getElementType() const;
|
||||
|
||||
/// Return the scope of the cooperative matrix.
|
||||
spirv::Scope getScope() const;
|
||||
/// return the number of rows of the matrix.
|
||||
unsigned getRows() const;
|
||||
/// return the number of columns of the matrix.
|
||||
unsigned getColumns() const;
|
||||
|
||||
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
|
||||
Optional<spirv::StorageClass> storage = llvm::None);
|
||||
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
|
||||
Optional<spirv::StorageClass> storage = llvm::None);
|
||||
};
|
||||
|
||||
} // end namespace spirv
|
||||
} // end namespace mlir
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/ParserUtils.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
||||
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
|
||||
|
@ -115,7 +116,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
|
|||
|
||||
SPIRVDialect::SPIRVDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
|
||||
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, PointerType,
|
||||
RuntimeArrayType, StructType>();
|
||||
|
||||
addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
|
||||
|
||||
|
@ -264,6 +266,36 @@ static Type parseArrayType(SPIRVDialect const &dialect,
|
|||
return ArrayType::get(elementType, count, stride);
|
||||
}
|
||||
|
||||
// cooperative-matrix-type ::= `!spv.coopmatrix` `<` element-type ',' scope ','
|
||||
// rows ',' coloumns>`
|
||||
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
|
||||
DialectAsmParser &parser) {
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
|
||||
SmallVector<int64_t, 2> dims;
|
||||
llvm::SMLoc countLoc = parser.getCurrentLocation();
|
||||
if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
|
||||
return Type();
|
||||
|
||||
if (dims.size() != 2) {
|
||||
parser.emitError(countLoc, "expected rows and columns size.");
|
||||
return Type();
|
||||
}
|
||||
|
||||
auto elementTy = parseAndVerifyType(dialect, parser);
|
||||
if (!elementTy)
|
||||
return Type();
|
||||
|
||||
Scope scope;
|
||||
if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
|
||||
return Type();
|
||||
|
||||
if (parser.parseGreater())
|
||||
return Type();
|
||||
return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
|
||||
}
|
||||
|
||||
// TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
|
||||
// methods in alphabetical order
|
||||
//
|
||||
|
@ -525,6 +557,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
|
|||
|
||||
if (keyword == "array")
|
||||
return parseArrayType(*this, parser);
|
||||
if (keyword == "coopmatrix")
|
||||
return parseCooperativeMatrixType(*this, parser);
|
||||
if (keyword == "image")
|
||||
return parseImageType(*this, parser);
|
||||
if (keyword == "ptr")
|
||||
|
@ -595,11 +629,20 @@ static void print(StructType type, DialectAsmPrinter &os) {
|
|||
os << ">";
|
||||
}
|
||||
|
||||
static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
|
||||
os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
|
||||
os << type.getElementType() << ", " << stringifyScope(type.getScope());
|
||||
os << ">";
|
||||
}
|
||||
|
||||
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
switch (type.getKind()) {
|
||||
case TypeKind::Array:
|
||||
print(type.cast<ArrayType>(), os);
|
||||
return;
|
||||
case TypeKind::CooperativeMatrix:
|
||||
print(type.cast<CooperativeMatrixNVType>(), os);
|
||||
return;
|
||||
case TypeKind::Pointer:
|
||||
print(type.cast<PointerType>(), os);
|
||||
return;
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
|
||||
#include "mlir/Dialect/SPIRV/ParserUtils.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
||||
|
@ -140,25 +141,6 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Parses the next keyword in `parser` as an enumerant of the given
|
||||
/// `EnumClass`.
|
||||
template <typename EnumClass>
|
||||
static ParseResult
|
||||
parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
|
||||
StringRef attrName = spirv::attributeName<EnumClass>()) {
|
||||
StringRef keyword;
|
||||
SmallVector<NamedAttribute, 1> attr;
|
||||
auto loc = parser.getCurrentLocation();
|
||||
if (parser.parseKeyword(&keyword))
|
||||
return failure();
|
||||
if (Optional<EnumClass> attr = spirv::symbolizeEnum<EnumClass>(keyword)) {
|
||||
value = attr.getValue();
|
||||
return success();
|
||||
}
|
||||
return parser.emitError(loc, "invalid ")
|
||||
<< attrName << " attribute specification: " << keyword;
|
||||
}
|
||||
|
||||
/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
|
||||
/// and inserts the enumerant into `state` as an 32-bit integer attribute with
|
||||
/// the enum class's name as attribute name.
|
||||
|
@ -2637,6 +2619,49 @@ static LogicalResult verify(spirv::VariableOp varOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.CooperativeMatrixLoadNV
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
|
||||
OperationState &state) {
|
||||
spirv::StorageClass storageClass;
|
||||
SmallVector<OpAsmParser::OperandType, 3> operandInfo;
|
||||
Type strideType = parser.getBuilder().getIntegerType(32);
|
||||
Type columnMajorType = parser.getBuilder().getIntegerType(1);
|
||||
Type elementType;
|
||||
if (parseEnumStrAttr(storageClass, parser) ||
|
||||
parser.parseOperandList(operandInfo, 3) ||
|
||||
parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
|
||||
parser.parseType(elementType)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto ptrType = spirv::PointerType::get(
|
||||
elementType.cast<spirv::CooperativeMatrixNVType>().getElementType(),
|
||||
storageClass);
|
||||
SmallVector<Type, 3> OperandType = {ptrType, strideType, columnMajorType};
|
||||
if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(),
|
||||
state.operands)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
state.addTypes(elementType);
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) {
|
||||
StringRef sc = stringifyStorageClass(
|
||||
M.pointer().getType().cast<spirv::PointerType>().getStorageClass());
|
||||
printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " \"" << sc
|
||||
<< "\" " << M.pointer() << ", " << M.stride() << ", "
|
||||
<< M.columnmajor();
|
||||
// Print optional memory access attribute.
|
||||
if (auto memAccess = M.memory_access())
|
||||
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
|
||||
printer << " : " << M.getType();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace spirv {
|
||||
|
||||
|
|
|
@ -158,6 +158,7 @@ void ArrayType::getCapabilities(
|
|||
bool CompositeType::classof(Type type) {
|
||||
switch (type.getKind()) {
|
||||
case TypeKind::Array:
|
||||
case TypeKind::CooperativeMatrix:
|
||||
case TypeKind::RuntimeArray:
|
||||
case TypeKind::Struct:
|
||||
return true;
|
||||
|
@ -177,6 +178,8 @@ Type CompositeType::getElementType(unsigned index) const {
|
|||
switch (getKind()) {
|
||||
case spirv::TypeKind::Array:
|
||||
return cast<ArrayType>().getElementType();
|
||||
case spirv::TypeKind::CooperativeMatrix:
|
||||
return cast<CooperativeMatrixNVType>().getElementType();
|
||||
case spirv::TypeKind::RuntimeArray:
|
||||
return cast<RuntimeArrayType>().getElementType();
|
||||
case spirv::TypeKind::Struct:
|
||||
|
@ -192,6 +195,9 @@ unsigned CompositeType::getNumElements() const {
|
|||
switch (getKind()) {
|
||||
case spirv::TypeKind::Array:
|
||||
return cast<ArrayType>().getNumElements();
|
||||
case spirv::TypeKind::CooperativeMatrix:
|
||||
return cast<CooperativeMatrixNVType>().getRows() *
|
||||
cast<CooperativeMatrixNVType>().getColumns();
|
||||
case spirv::TypeKind::RuntimeArray:
|
||||
llvm_unreachable(
|
||||
"invalid to query number of elements of spirv::RuntimeArray type");
|
||||
|
@ -211,6 +217,9 @@ void CompositeType::getExtensions(
|
|||
case spirv::TypeKind::Array:
|
||||
cast<ArrayType>().getExtensions(extensions, storage);
|
||||
break;
|
||||
case spirv::TypeKind::CooperativeMatrix:
|
||||
cast<CooperativeMatrixNVType>().getExtensions(extensions, storage);
|
||||
break;
|
||||
case spirv::TypeKind::RuntimeArray:
|
||||
cast<RuntimeArrayType>().getExtensions(extensions, storage);
|
||||
break;
|
||||
|
@ -233,6 +242,9 @@ void CompositeType::getCapabilities(
|
|||
case spirv::TypeKind::Array:
|
||||
cast<ArrayType>().getCapabilities(capabilities, storage);
|
||||
break;
|
||||
case spirv::TypeKind::CooperativeMatrix:
|
||||
cast<CooperativeMatrixNVType>().getCapabilities(capabilities, storage);
|
||||
break;
|
||||
case spirv::TypeKind::RuntimeArray:
|
||||
cast<RuntimeArrayType>().getCapabilities(capabilities, storage);
|
||||
break;
|
||||
|
@ -248,6 +260,70 @@ void CompositeType::getCapabilities(
|
|||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CooperativeMatrixType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
|
||||
using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
|
||||
|
||||
static CooperativeMatrixTypeStorage *
|
||||
construct(TypeStorageAllocator &allocator, const KeyTy &key) {
|
||||
return new (allocator.allocate<CooperativeMatrixTypeStorage>())
|
||||
CooperativeMatrixTypeStorage(key);
|
||||
}
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == KeyTy(elementType, getScope(), rows, columns);
|
||||
}
|
||||
|
||||
CooperativeMatrixTypeStorage(const KeyTy &key)
|
||||
: TypeStorage(static_cast<unsigned>(std::get<1>(key))),
|
||||
elementType(std::get<0>(key)), rows(std::get<2>(key)),
|
||||
columns(std::get<3>(key)) {}
|
||||
|
||||
Scope getScope() const { return static_cast<Scope>(getSubclassData()); }
|
||||
|
||||
Type elementType;
|
||||
unsigned rows;
|
||||
unsigned columns;
|
||||
};
|
||||
|
||||
CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
|
||||
Scope scope, unsigned rows,
|
||||
unsigned columns) {
|
||||
return Base::get(elementType.getContext(), TypeKind::CooperativeMatrix,
|
||||
elementType, scope, rows, columns);
|
||||
}
|
||||
|
||||
Type CooperativeMatrixNVType::getElementType() const {
|
||||
return getImpl()->elementType;
|
||||
}
|
||||
|
||||
Scope CooperativeMatrixNVType::getScope() const {
|
||||
return getImpl()->getScope();
|
||||
}
|
||||
|
||||
unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
|
||||
|
||||
unsigned CooperativeMatrixNVType::getColumns() const {
|
||||
return getImpl()->columns;
|
||||
}
|
||||
|
||||
void CooperativeMatrixNVType::getExtensions(
|
||||
SPIRVType::ExtensionArrayRefVector &extensions,
|
||||
Optional<StorageClass> storage) {
|
||||
getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
|
||||
extensions.push_back(Extension::SPV_NV_cooperative_matrix);
|
||||
}
|
||||
|
||||
void CooperativeMatrixNVType::getCapabilities(
|
||||
SPIRVType::CapabilityArrayRefVector &capabilities,
|
||||
Optional<StorageClass> storage) {
|
||||
getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
|
||||
capabilities.push_back(Capability::CooperativeMatrixNV);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ImageType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -217,6 +217,8 @@ private:
|
|||
|
||||
LogicalResult processArrayType(ArrayRef<uint32_t> operands);
|
||||
|
||||
LogicalResult processCooperativeMatrixType(ArrayRef<uint32_t> operands);
|
||||
|
||||
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
|
||||
|
||||
LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
|
||||
|
@ -1160,6 +1162,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
|
|||
} break;
|
||||
case spirv::Opcode::OpTypeArray:
|
||||
return processArrayType(operands);
|
||||
case spirv::Opcode::OpTypeCooperativeMatrixNV:
|
||||
return processCooperativeMatrixType(operands);
|
||||
case spirv::Opcode::OpTypeFunction:
|
||||
return processFunctionType(operands);
|
||||
case spirv::Opcode::OpTypeRuntimeArray:
|
||||
|
@ -1229,6 +1233,35 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
|
||||
if (operands.size() != 5) {
|
||||
return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
|
||||
"type and row x column parameters");
|
||||
}
|
||||
|
||||
Type elementTy = getType(operands[1]);
|
||||
if (!elementTy) {
|
||||
return emitError(unknownLoc,
|
||||
"OpTypeCooperativeMatrix references undefined <id> ")
|
||||
<< operands[1];
|
||||
}
|
||||
|
||||
auto scope = spirv::symbolizeScope(operands[2]);
|
||||
if (!scope) {
|
||||
return emitError(unknownLoc,
|
||||
"OpTypeCooperativeMatrix references undefined scope <id> ")
|
||||
<< operands[2];
|
||||
}
|
||||
|
||||
unsigned rows = operands[3];
|
||||
unsigned columns = operands[4];
|
||||
|
||||
typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
|
||||
elementTy, scope.getValue(), rows, columns);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
|
||||
if (operands.size() != 2) {
|
||||
|
@ -2210,6 +2243,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
|
|||
case spirv::Opcode::OpTypeRuntimeArray:
|
||||
case spirv::Opcode::OpTypeStruct:
|
||||
case spirv::Opcode::OpTypePointer:
|
||||
case spirv::Opcode::OpTypeCooperativeMatrixNV:
|
||||
return processType(opcode, operands);
|
||||
case spirv::Opcode::OpConstant:
|
||||
return processConstant(operands, /*isSpec=*/false);
|
||||
|
|
|
@ -1096,6 +1096,21 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
|
|||
return success();
|
||||
}
|
||||
|
||||
if (auto cooperativeMatrixType =
|
||||
type.dyn_cast<spirv::CooperativeMatrixNVType>()) {
|
||||
uint32_t elementTypeID = 0;
|
||||
if (failed(processType(loc, cooperativeMatrixType.getElementType(),
|
||||
elementTypeID))) {
|
||||
return failure();
|
||||
}
|
||||
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
|
||||
operands.push_back(elementTypeID);
|
||||
operands.push_back(static_cast<uint32_t>(cooperativeMatrixType.getScope()));
|
||||
operands.push_back(cooperativeMatrixType.getRows());
|
||||
operands.push_back(cooperativeMatrixType.getColumns());
|
||||
return success();
|
||||
}
|
||||
|
||||
// TODO(ravishankarm) : Handle other types.
|
||||
return emitError(loc, "unhandled type in serialization: ") << type;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
|
||||
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> {
|
||||
// CHECK-LABEL: @cooperative_matrix_load
|
||||
spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
|
||||
// CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
|
||||
%0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cooperative_matrix_load_memaccess
|
||||
spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
|
||||
// CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
|
||||
%0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
|
||||
spv.Return
|
||||
}
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @cooperative_matrix_load
|
||||
spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
|
||||
// CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
|
||||
%0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @cooperative_matrix_load_memaccess
|
||||
spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
|
||||
// CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
|
||||
%0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
|
||||
spv.Return
|
||||
}
|
|
@ -327,3 +327,23 @@ func @struct_type_missing_comma(!spv.struct<f32 [0 NonWritable], i32 [4]>)
|
|||
|
||||
// expected-error @+1 {{expected ']'}}
|
||||
func @struct_type_missing_comma(!spv.struct<f32 [0, NonWritable NonReadable], i32 [4]>)
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CooperativeMatrix
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: func @coop_matrix_type(!spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<8x8xf32, Workgroup>)
|
||||
func @coop_matrix_type(!spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<8x8xf32, Workgroup>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected ','}}
|
||||
func @missing_scope(!spv.coopmatrix<8x16xi32>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected rows and columns size}}
|
||||
func @missing_count(!spv.coopmatrix<8xi32, Subgroup>) -> ()
|
||||
|
||||
|
|
Loading…
Reference in New Issue