forked from OSchip/llvm-project
[mlir][spirv] Add support for matrix type
This commit adds basic matrix type support to the SPIR-V dialect including type definition, IR assembly, parsing, printing, and (de)serialization. Differential Revision: https://reviews.llvm.org/D80594
This commit is contained in:
parent
971459c3ef
commit
915e55c910
|
@ -3109,6 +3109,7 @@ def SPV_OC_OpTypeBool : I32EnumAttrCase<"OpTypeBool", 20>;
|
|||
def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>;
|
||||
def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
|
||||
def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
|
||||
def SPV_OC_OpTypeMatrix : I32EnumAttrCase<"OpTypeMatrix", 24>;
|
||||
def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
|
||||
def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
|
||||
def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>;
|
||||
|
@ -3250,15 +3251,15 @@ def SPV_OpcodeAttr :
|
|||
SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
|
||||
SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
|
||||
SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
|
||||
SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
|
||||
SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer,
|
||||
SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse,
|
||||
SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull,
|
||||
SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
|
||||
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
|
||||
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
|
||||
SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
|
||||
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
|
||||
SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix,
|
||||
SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
|
||||
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
|
||||
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
|
||||
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
|
||||
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
|
||||
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
|
||||
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
|
||||
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
|
||||
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU,
|
||||
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
|
||||
SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
#ifndef MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
|
||||
#define MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
|
||||
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
@ -56,9 +58,11 @@ namespace detail {
|
|||
struct ArrayTypeStorage;
|
||||
struct CooperativeMatrixTypeStorage;
|
||||
struct ImageTypeStorage;
|
||||
struct MatrixTypeStorage;
|
||||
struct PointerTypeStorage;
|
||||
struct RuntimeArrayTypeStorage;
|
||||
struct StructTypeStorage;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
namespace TypeKind {
|
||||
|
@ -66,6 +70,7 @@ enum Kind {
|
|||
Array = Type::FIRST_SPIRV_TYPE,
|
||||
CooperativeMatrix,
|
||||
Image,
|
||||
Matrix,
|
||||
Pointer,
|
||||
RuntimeArray,
|
||||
Struct,
|
||||
|
@ -366,6 +371,36 @@ public:
|
|||
Optional<spirv::StorageClass> storage = llvm::None);
|
||||
};
|
||||
|
||||
// SPIR-V matrix type
|
||||
class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
|
||||
detail::MatrixTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static bool kindof(unsigned kind) { return kind == TypeKind::Matrix; }
|
||||
|
||||
static MatrixType get(Type columnType, uint32_t columnCount);
|
||||
|
||||
static MatrixType getChecked(Type columnType, uint32_t columnCount,
|
||||
Location location);
|
||||
|
||||
static LogicalResult verifyConstructionInvariants(Location loc,
|
||||
Type columnType,
|
||||
uint32_t columnCount);
|
||||
|
||||
/// Returns true if the matrix elements are vectors of float elements
|
||||
static bool isValidColumnType(Type columnType);
|
||||
|
||||
Type getElementType() const;
|
||||
|
||||
unsigned getNumElements() const;
|
||||
|
||||
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
|
||||
Optional<spirv::StorageClass> storage = llvm::None);
|
||||
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
|
||||
Optional<spirv::StorageClass> storage = llvm::None);
|
||||
};
|
||||
|
||||
} // end namespace spirv
|
||||
} // end namespace mlir
|
||||
|
||||
|
|
|
@ -116,8 +116,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
|
|||
|
||||
SPIRVDialect::SPIRVDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, PointerType,
|
||||
RuntimeArrayType, StructType>();
|
||||
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
|
||||
PointerType, RuntimeArrayType, StructType>();
|
||||
|
||||
addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
|
||||
|
||||
|
@ -197,6 +197,42 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
|
|||
return type;
|
||||
}
|
||||
|
||||
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
|
||||
DialectAsmParser &parser) {
|
||||
Type type;
|
||||
llvm::SMLoc typeLoc = parser.getCurrentLocation();
|
||||
if (parser.parseType(type))
|
||||
return Type();
|
||||
|
||||
if (auto t = type.dyn_cast<VectorType>()) {
|
||||
if (t.getRank() != 1) {
|
||||
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
|
||||
return Type();
|
||||
}
|
||||
if (t.getNumElements() > 4 || t.getNumElements() < 2) {
|
||||
parser.emitError(typeLoc,
|
||||
"matrix columns size has to be less than or equal "
|
||||
"to 4 and greater than or equal 2, but found ")
|
||||
<< t.getNumElements();
|
||||
return Type();
|
||||
}
|
||||
|
||||
if (!t.getElementType().isa<FloatType>()) {
|
||||
parser.emitError(typeLoc, "matrix columns' elements must be of "
|
||||
"Float type, got ")
|
||||
<< t.getElementType();
|
||||
return Type();
|
||||
}
|
||||
} else {
|
||||
parser.emitError(typeLoc, "matrix must be composed using vector "
|
||||
"type, got ")
|
||||
<< type;
|
||||
return Type();
|
||||
}
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
/// Parses an optional `, stride = N` assembly segment. If no parsing failure
|
||||
/// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
|
||||
/// missing.
|
||||
|
@ -279,7 +315,7 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
|
|||
return Type();
|
||||
|
||||
if (dims.size() != 2) {
|
||||
parser.emitError(countLoc, "expected rows and columns size.");
|
||||
parser.emitError(countLoc, "expected rows and columns size");
|
||||
return Type();
|
||||
}
|
||||
|
||||
|
@ -350,6 +386,40 @@ static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
|
|||
return RuntimeArrayType::get(elementType, stride);
|
||||
}
|
||||
|
||||
// matrix-type ::= `!spv.matrix` `<` integer-literal `x` element-type `>`
|
||||
static Type parseMatrixType(SPIRVDialect const &dialect,
|
||||
DialectAsmParser &parser) {
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
|
||||
SmallVector<int64_t, 1> countDims;
|
||||
llvm::SMLoc countLoc = parser.getCurrentLocation();
|
||||
if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
|
||||
return Type();
|
||||
if (countDims.size() != 1) {
|
||||
parser.emitError(countLoc, "expected single unsigned "
|
||||
"integer for number of columns");
|
||||
return Type();
|
||||
}
|
||||
|
||||
int64_t columnCount = countDims[0];
|
||||
// According to the specification, Matrices can have 2, 3, or 4 columns
|
||||
if (columnCount < 2 || columnCount > 4) {
|
||||
parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
|
||||
"columns");
|
||||
return Type();
|
||||
}
|
||||
|
||||
Type columnType = parseAndVerifyMatrixType(dialect, parser);
|
||||
if (!columnType)
|
||||
return Type();
|
||||
|
||||
if (parser.parseGreater())
|
||||
return Type();
|
||||
|
||||
return MatrixType::get(columnType, columnCount);
|
||||
}
|
||||
|
||||
// Specialize this function to parse each of the parameters that define an
|
||||
// ImageType. By default it assumes this is an enum type.
|
||||
template <typename ValTy>
|
||||
|
@ -567,7 +637,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
|
|||
return parseRuntimeArrayType(*this, parser);
|
||||
if (keyword == "struct")
|
||||
return parseStructType(*this, parser);
|
||||
|
||||
if (keyword == "matrix")
|
||||
return parseMatrixType(*this, parser);
|
||||
parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
|
||||
return Type();
|
||||
}
|
||||
|
@ -635,6 +706,11 @@ static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
|
|||
os << ">";
|
||||
}
|
||||
|
||||
static void print(MatrixType type, DialectAsmPrinter &os) {
|
||||
os << "matrix<" << type.getNumElements() << " x " << type.getElementType();
|
||||
os << ">";
|
||||
}
|
||||
|
||||
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
switch (type.getKind()) {
|
||||
case TypeKind::Array:
|
||||
|
@ -655,6 +731,9 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
|
|||
case TypeKind::Struct:
|
||||
print(type.cast<StructType>(), os);
|
||||
return;
|
||||
case TypeKind::Matrix:
|
||||
print(type.cast<MatrixType>(), os);
|
||||
return;
|
||||
default:
|
||||
llvm_unreachable("unhandled SPIR-V type");
|
||||
}
|
||||
|
|
|
@ -159,6 +159,7 @@ bool CompositeType::classof(Type type) {
|
|||
switch (type.getKind()) {
|
||||
case TypeKind::Array:
|
||||
case TypeKind::CooperativeMatrix:
|
||||
case TypeKind::Matrix:
|
||||
case TypeKind::RuntimeArray:
|
||||
case TypeKind::Struct:
|
||||
return true;
|
||||
|
@ -180,6 +181,8 @@ Type CompositeType::getElementType(unsigned index) const {
|
|||
return cast<ArrayType>().getElementType();
|
||||
case spirv::TypeKind::CooperativeMatrix:
|
||||
return cast<CooperativeMatrixNVType>().getElementType();
|
||||
case spirv::TypeKind::Matrix:
|
||||
return cast<MatrixType>().getElementType();
|
||||
case spirv::TypeKind::RuntimeArray:
|
||||
return cast<RuntimeArrayType>().getElementType();
|
||||
case spirv::TypeKind::Struct:
|
||||
|
@ -198,6 +201,8 @@ unsigned CompositeType::getNumElements() const {
|
|||
case spirv::TypeKind::CooperativeMatrix:
|
||||
llvm_unreachable(
|
||||
"invalid to query number of elements of spirv::CooperativeMatrix type");
|
||||
case spirv::TypeKind::Matrix:
|
||||
return cast<MatrixType>().getNumElements();
|
||||
case spirv::TypeKind::RuntimeArray:
|
||||
llvm_unreachable(
|
||||
"invalid to query number of elements of spirv::RuntimeArray type");
|
||||
|
@ -230,6 +235,9 @@ void CompositeType::getExtensions(
|
|||
case spirv::TypeKind::CooperativeMatrix:
|
||||
cast<CooperativeMatrixNVType>().getExtensions(extensions, storage);
|
||||
break;
|
||||
case spirv::TypeKind::Matrix:
|
||||
cast<MatrixType>().getExtensions(extensions, storage);
|
||||
break;
|
||||
case spirv::TypeKind::RuntimeArray:
|
||||
cast<RuntimeArrayType>().getExtensions(extensions, storage);
|
||||
break;
|
||||
|
@ -255,6 +263,9 @@ void CompositeType::getCapabilities(
|
|||
case spirv::TypeKind::CooperativeMatrix:
|
||||
cast<CooperativeMatrixNVType>().getCapabilities(capabilities, storage);
|
||||
break;
|
||||
case spirv::TypeKind::Matrix:
|
||||
cast<MatrixType>().getCapabilities(capabilities, storage);
|
||||
break;
|
||||
case spirv::TypeKind::RuntimeArray:
|
||||
cast<RuntimeArrayType>().getCapabilities(capabilities, storage);
|
||||
break;
|
||||
|
@ -823,10 +834,12 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
|
|||
scalarType.getExtensions(extensions, storage);
|
||||
} else if (auto compositeType = dyn_cast<CompositeType>()) {
|
||||
compositeType.getExtensions(extensions, storage);
|
||||
} else if (auto ptrType = dyn_cast<PointerType>()) {
|
||||
ptrType.getExtensions(extensions, storage);
|
||||
} else if (auto imageType = dyn_cast<ImageType>()) {
|
||||
imageType.getExtensions(extensions, storage);
|
||||
} else if (auto matrixType = dyn_cast<MatrixType>()) {
|
||||
matrixType.getExtensions(extensions, storage);
|
||||
} else if (auto ptrType = dyn_cast<PointerType>()) {
|
||||
ptrType.getExtensions(extensions, storage);
|
||||
} else {
|
||||
llvm_unreachable("invalid SPIR-V Type to getExtensions");
|
||||
}
|
||||
|
@ -839,10 +852,12 @@ void SPIRVType::getCapabilities(
|
|||
scalarType.getCapabilities(capabilities, storage);
|
||||
} else if (auto compositeType = dyn_cast<CompositeType>()) {
|
||||
compositeType.getCapabilities(capabilities, storage);
|
||||
} else if (auto ptrType = dyn_cast<PointerType>()) {
|
||||
ptrType.getCapabilities(capabilities, storage);
|
||||
} else if (auto imageType = dyn_cast<ImageType>()) {
|
||||
imageType.getCapabilities(capabilities, storage);
|
||||
} else if (auto matrixType = dyn_cast<MatrixType>()) {
|
||||
matrixType.getCapabilities(capabilities, storage);
|
||||
} else if (auto ptrType = dyn_cast<PointerType>()) {
|
||||
ptrType.getCapabilities(capabilities, storage);
|
||||
} else {
|
||||
llvm_unreachable("invalid SPIR-V Type to getCapabilities");
|
||||
}
|
||||
|
@ -1000,3 +1015,89 @@ void StructType::getCapabilities(
|
|||
for (Type elementType : getElementTypes())
|
||||
elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MatrixType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct spirv::detail::MatrixTypeStorage : public TypeStorage {
|
||||
MatrixTypeStorage(Type columnType, uint32_t columnCount)
|
||||
: TypeStorage(), columnType(columnType), columnCount(columnCount) {}
|
||||
|
||||
using KeyTy = std::tuple<Type, uint32_t>;
|
||||
|
||||
static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
return new (allocator.allocate<MatrixTypeStorage>())
|
||||
MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
|
||||
}
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == KeyTy(columnType, columnCount);
|
||||
}
|
||||
|
||||
Type columnType;
|
||||
const uint32_t columnCount;
|
||||
};
|
||||
|
||||
MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
|
||||
return Base::get(columnType.getContext(), TypeKind::Matrix, columnType,
|
||||
columnCount);
|
||||
}
|
||||
|
||||
MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
|
||||
Location location) {
|
||||
return Base::getChecked(location, TypeKind::Matrix, columnType, columnCount);
|
||||
}
|
||||
|
||||
LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
|
||||
Type columnType,
|
||||
uint32_t columnCount) {
|
||||
if (columnCount < 2 || columnCount > 4)
|
||||
return emitError(loc, "matrix can have 2, 3, or 4 columns only");
|
||||
|
||||
if (!isValidColumnType(columnType))
|
||||
return emitError(loc, "matrix columns must be vectors of floats");
|
||||
|
||||
/// The underlying vectors (columns) must be of size 2, 3, or 4
|
||||
ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
|
||||
if (columnShape.size() != 1)
|
||||
return emitError(loc, "matrix columns must be 1D vectors");
|
||||
|
||||
if (columnShape[0] < 2 || columnShape[0] > 4)
|
||||
return emitError(loc, "matrix columns must be of size 2, 3, or 4");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Returns true if the matrix elements are vectors of float elements
|
||||
bool MatrixType::isValidColumnType(Type columnType) {
|
||||
if (auto vectorType = columnType.dyn_cast<VectorType>()) {
|
||||
if (vectorType.getElementType().isa<FloatType>())
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Type MatrixType::getElementType() const { return getImpl()->columnType; }
|
||||
|
||||
unsigned MatrixType::getNumElements() const { return getImpl()->columnCount; }
|
||||
|
||||
void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
|
||||
Optional<StorageClass> storage) {
|
||||
getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
|
||||
}
|
||||
|
||||
void MatrixType::getCapabilities(
|
||||
SPIRVType::CapabilityArrayRefVector &capabilities,
|
||||
Optional<StorageClass> storage) {
|
||||
{
|
||||
static const Capability caps[] = {Capability::Matrix};
|
||||
ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
|
||||
capabilities.push_back(ref);
|
||||
}
|
||||
// Add any capabilities associated with the underlying vectors (i.e., columns)
|
||||
getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
|
||||
}
|
||||
|
|
|
@ -225,6 +225,8 @@ private:
|
|||
|
||||
LogicalResult processStructType(ArrayRef<uint32_t> operands);
|
||||
|
||||
LogicalResult processMatrixType(ArrayRef<uint32_t> operands);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Constant
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -1170,6 +1172,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
|
|||
return processRuntimeArrayType(operands);
|
||||
case spirv::Opcode::OpTypeStruct:
|
||||
return processStructType(operands);
|
||||
case spirv::Opcode::OpTypeMatrix:
|
||||
return processMatrixType(operands);
|
||||
default:
|
||||
return emitError(unknownLoc, "unhandled type instruction");
|
||||
}
|
||||
|
@ -1333,6 +1337,25 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
|
||||
if (operands.size() != 3) {
|
||||
// Three operands are needed: result_id, column_type, and column_count
|
||||
return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
|
||||
" (result_id, column_type, and column_count)");
|
||||
}
|
||||
// Matrix columns must be of vector type
|
||||
Type elementTy = getType(operands[1]);
|
||||
if (!elementTy) {
|
||||
return emitError(unknownLoc,
|
||||
"OpTypeMatrix references undefined column type.")
|
||||
<< operands[1];
|
||||
}
|
||||
|
||||
uint32_t colsCount = operands[2];
|
||||
typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Constant
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2238,6 +2261,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
|
|||
case spirv::Opcode::OpTypeInt:
|
||||
case spirv::Opcode::OpTypeFloat:
|
||||
case spirv::Opcode::OpTypeVector:
|
||||
case spirv::Opcode::OpTypeMatrix:
|
||||
case spirv::Opcode::OpTypeArray:
|
||||
case spirv::Opcode::OpTypeFunction:
|
||||
case spirv::Opcode::OpTypeRuntimeArray:
|
||||
|
|
|
@ -1111,6 +1111,17 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
|
|||
return success();
|
||||
}
|
||||
|
||||
if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
|
||||
uint32_t elementTypeID = 0;
|
||||
if (failed(processType(loc, matrixType.getElementType(), elementTypeID))) {
|
||||
return failure();
|
||||
}
|
||||
typeEnum = spirv::Opcode::OpTypeMatrix;
|
||||
operands.push_back(elementTypeID);
|
||||
operands.push_back(matrixType.getNumElements());
|
||||
return success();
|
||||
}
|
||||
|
||||
// TODO(ravishankarm) : Handle other types.
|
||||
return emitError(loc, "unhandled type in serialization: ") << type;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s
|
||||
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||
spv.func @matrix_type(%arg0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>, %arg1 : i32) "None" {
|
||||
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
|
||||
%2 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||
// CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
|
||||
spv.globalVariable @var0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
|
||||
|
||||
// CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.matrix<2 x vector<3xf32>>, StorageBuffer>
|
||||
spv.globalVariable @var1 : !spv.ptr<!spv.matrix<2 x vector<3xf32>>, StorageBuffer>
|
||||
|
||||
// CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.matrix<4 x vector<4xf16>>, StorageBuffer>
|
||||
spv.globalVariable @var2 : !spv.ptr<!spv.matrix<4 x vector<4xf16>>, StorageBuffer>
|
||||
}
|
|
@ -347,3 +347,87 @@ func @missing_scope(!spv.coopmatrix<8x16xi32>) -> ()
|
|||
// expected-error @+1 {{expected rows and columns size}}
|
||||
func @missing_count(!spv.coopmatrix<8xi32, Subgroup>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Matrix
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CHECK: func @matrix_type(!spv.matrix<2 x vector<2xf16>>)
|
||||
func @matrix_type(!spv.matrix<2 x vector<2xf16>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @matrix_type(!spv.matrix<3 x vector<3xf32>>)
|
||||
func @matrix_type(!spv.matrix<3 x vector<3xf32>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @matrix_type(!spv.matrix<4 x vector<4xf16>>)
|
||||
func @matrix_type(!spv.matrix<4 x vector<4xf16>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
|
||||
func @matrix_invalid_size(!spv.matrix<5 x vector<3xf32>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
|
||||
func @matrix_invalid_size(!spv.matrix<1 x vector<3xf32>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{matrix columns size has to be less than or equal to 4 and greater than or equal 2, but found 5}}
|
||||
func @matrix_invalid_columns_size(!spv.matrix<3 x vector<5xf32>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{matrix columns size has to be less than or equal to 4 and greater than or equal 2, but found 1}}
|
||||
func @matrix_invalid_columns_size(!spv.matrix<3 x vector<1xf32>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected '<'}}
|
||||
func @matrix_invalid_format(!spv.matrix 3 x vector<3xf32>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{unbalanced ')' character in pretty dialect name}}
|
||||
func @matrix_invalid_format(!spv.matrix< 3 x vector<3xf32>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected 'x' in dimension list}}
|
||||
func @matrix_invalid_format(!spv.matrix<2 vector<3xi32>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{matrix must be composed using vector type, got 'i32'}}
|
||||
func @matrix_invalid_type(!spv.matrix< 3 x i32>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{matrix must be composed using vector type, got '!spv.array<16 x f32>'}}
|
||||
func @matrix_invalid_type(!spv.matrix< 3 x !spv.array<16 x f32>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{matrix must be composed using vector type, got '!spv.rtarray<i32>'}}
|
||||
func @matrix_invalid_type(!spv.matrix< 3 x !spv.rtarray<i32>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{matrix columns' elements must be of Float type, got 'i32'}}
|
||||
func @matrix_invalid_type(!spv.matrix<2 x vector<3xi32>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected single unsigned integer for number of columns}}
|
||||
func @matrix_size_type(!spv.matrix< x vector<3xi32>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected single unsigned integer for number of columns}}
|
||||
func @matrix_size_type(!spv.matrix<2.0 x vector<3xi32>>) -> ()
|
||||
|
||||
// -----
|
Loading…
Reference in New Issue