[mlir][spirv] Add support for matrix type

This commit adds basic matrix type support to the SPIR-V dialect
including type definition, IR assembly, parsing, printing, and
(de)serialization.

Differential Revision: https://reviews.llvm.org/D80594
This commit is contained in:
HazemAbdelhafez 2020-06-02 16:22:38 -04:00 committed by Lei Zhang
parent 971459c3ef
commit 915e55c910
8 changed files with 374 additions and 17 deletions

View File

@ -3109,6 +3109,7 @@ def SPV_OC_OpTypeBool : I32EnumAttrCase<"OpTypeBool", 20>;
def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>;
def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
def SPV_OC_OpTypeMatrix : I32EnumAttrCase<"OpTypeMatrix", 24>;
def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>;
@ -3250,15 +3251,15 @@ def SPV_OpcodeAttr :
SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer,
SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse,
SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull,
SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix,
SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU,
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,

View File

@ -13,6 +13,8 @@
#ifndef MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
#define MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
@ -56,9 +58,11 @@ namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct StructTypeStorage;
} // namespace detail
namespace TypeKind {
@ -66,6 +70,7 @@ enum Kind {
Array = Type::FIRST_SPIRV_TYPE,
CooperativeMatrix,
Image,
Matrix,
Pointer,
RuntimeArray,
Struct,
@ -366,6 +371,36 @@ public:
Optional<spirv::StorageClass> storage = llvm::None);
};
// SPIR-V matrix type
class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
detail::MatrixTypeStorage> {
public:
using Base::Base;
static bool kindof(unsigned kind) { return kind == TypeKind::Matrix; }
static MatrixType get(Type columnType, uint32_t columnCount);
static MatrixType getChecked(Type columnType, uint32_t columnCount,
Location location);
static LogicalResult verifyConstructionInvariants(Location loc,
Type columnType,
uint32_t columnCount);
/// Returns true if the matrix elements are vectors of float elements
static bool isValidColumnType(Type columnType);
Type getElementType() const;
unsigned getNumElements() const;
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<spirv::StorageClass> storage = llvm::None);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<spirv::StorageClass> storage = llvm::None);
};
} // end namespace spirv
} // end namespace mlir

View File

@ -116,8 +116,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
SPIRVDialect::SPIRVDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, PointerType,
RuntimeArrayType, StructType>();
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
PointerType, RuntimeArrayType, StructType>();
addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
@ -197,6 +197,42 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
return type;
}
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
Type type;
llvm::SMLoc typeLoc = parser.getCurrentLocation();
if (parser.parseType(type))
return Type();
if (auto t = type.dyn_cast<VectorType>()) {
if (t.getRank() != 1) {
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
return Type();
}
if (t.getNumElements() > 4 || t.getNumElements() < 2) {
parser.emitError(typeLoc,
"matrix columns size has to be less than or equal "
"to 4 and greater than or equal 2, but found ")
<< t.getNumElements();
return Type();
}
if (!t.getElementType().isa<FloatType>()) {
parser.emitError(typeLoc, "matrix columns' elements must be of "
"Float type, got ")
<< t.getElementType();
return Type();
}
} else {
parser.emitError(typeLoc, "matrix must be composed using vector "
"type, got ")
<< type;
return Type();
}
return type;
}
/// Parses an optional `, stride = N` assembly segment. If no parsing failure
/// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
/// missing.
@ -279,7 +315,7 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
return Type();
if (dims.size() != 2) {
parser.emitError(countLoc, "expected rows and columns size.");
parser.emitError(countLoc, "expected rows and columns size");
return Type();
}
@ -350,6 +386,40 @@ static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
return RuntimeArrayType::get(elementType, stride);
}
// matrix-type ::= `!spv.matrix` `<` integer-literal `x` element-type `>`
static Type parseMatrixType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
return Type();
SmallVector<int64_t, 1> countDims;
llvm::SMLoc countLoc = parser.getCurrentLocation();
if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
return Type();
if (countDims.size() != 1) {
parser.emitError(countLoc, "expected single unsigned "
"integer for number of columns");
return Type();
}
int64_t columnCount = countDims[0];
// According to the specification, Matrices can have 2, 3, or 4 columns
if (columnCount < 2 || columnCount > 4) {
parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
"columns");
return Type();
}
Type columnType = parseAndVerifyMatrixType(dialect, parser);
if (!columnType)
return Type();
if (parser.parseGreater())
return Type();
return MatrixType::get(columnType, columnCount);
}
// Specialize this function to parse each of the parameters that define an
// ImageType. By default it assumes this is an enum type.
template <typename ValTy>
@ -567,7 +637,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseRuntimeArrayType(*this, parser);
if (keyword == "struct")
return parseStructType(*this, parser);
if (keyword == "matrix")
return parseMatrixType(*this, parser);
parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
return Type();
}
@ -635,6 +706,11 @@ static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
os << ">";
}
static void print(MatrixType type, DialectAsmPrinter &os) {
os << "matrix<" << type.getNumElements() << " x " << type.getElementType();
os << ">";
}
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
switch (type.getKind()) {
case TypeKind::Array:
@ -655,6 +731,9 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
case TypeKind::Struct:
print(type.cast<StructType>(), os);
return;
case TypeKind::Matrix:
print(type.cast<MatrixType>(), os);
return;
default:
llvm_unreachable("unhandled SPIR-V type");
}

View File

@ -159,6 +159,7 @@ bool CompositeType::classof(Type type) {
switch (type.getKind()) {
case TypeKind::Array:
case TypeKind::CooperativeMatrix:
case TypeKind::Matrix:
case TypeKind::RuntimeArray:
case TypeKind::Struct:
return true;
@ -180,6 +181,8 @@ Type CompositeType::getElementType(unsigned index) const {
return cast<ArrayType>().getElementType();
case spirv::TypeKind::CooperativeMatrix:
return cast<CooperativeMatrixNVType>().getElementType();
case spirv::TypeKind::Matrix:
return cast<MatrixType>().getElementType();
case spirv::TypeKind::RuntimeArray:
return cast<RuntimeArrayType>().getElementType();
case spirv::TypeKind::Struct:
@ -198,6 +201,8 @@ unsigned CompositeType::getNumElements() const {
case spirv::TypeKind::CooperativeMatrix:
llvm_unreachable(
"invalid to query number of elements of spirv::CooperativeMatrix type");
case spirv::TypeKind::Matrix:
return cast<MatrixType>().getNumElements();
case spirv::TypeKind::RuntimeArray:
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
@ -230,6 +235,9 @@ void CompositeType::getExtensions(
case spirv::TypeKind::CooperativeMatrix:
cast<CooperativeMatrixNVType>().getExtensions(extensions, storage);
break;
case spirv::TypeKind::Matrix:
cast<MatrixType>().getExtensions(extensions, storage);
break;
case spirv::TypeKind::RuntimeArray:
cast<RuntimeArrayType>().getExtensions(extensions, storage);
break;
@ -255,6 +263,9 @@ void CompositeType::getCapabilities(
case spirv::TypeKind::CooperativeMatrix:
cast<CooperativeMatrixNVType>().getCapabilities(capabilities, storage);
break;
case spirv::TypeKind::Matrix:
cast<MatrixType>().getCapabilities(capabilities, storage);
break;
case spirv::TypeKind::RuntimeArray:
cast<RuntimeArrayType>().getCapabilities(capabilities, storage);
break;
@ -823,10 +834,12 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
scalarType.getExtensions(extensions, storage);
} else if (auto compositeType = dyn_cast<CompositeType>()) {
compositeType.getExtensions(extensions, storage);
} else if (auto ptrType = dyn_cast<PointerType>()) {
ptrType.getExtensions(extensions, storage);
} else if (auto imageType = dyn_cast<ImageType>()) {
imageType.getExtensions(extensions, storage);
} else if (auto matrixType = dyn_cast<MatrixType>()) {
matrixType.getExtensions(extensions, storage);
} else if (auto ptrType = dyn_cast<PointerType>()) {
ptrType.getExtensions(extensions, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getExtensions");
}
@ -839,10 +852,12 @@ void SPIRVType::getCapabilities(
scalarType.getCapabilities(capabilities, storage);
} else if (auto compositeType = dyn_cast<CompositeType>()) {
compositeType.getCapabilities(capabilities, storage);
} else if (auto ptrType = dyn_cast<PointerType>()) {
ptrType.getCapabilities(capabilities, storage);
} else if (auto imageType = dyn_cast<ImageType>()) {
imageType.getCapabilities(capabilities, storage);
} else if (auto matrixType = dyn_cast<MatrixType>()) {
matrixType.getCapabilities(capabilities, storage);
} else if (auto ptrType = dyn_cast<PointerType>()) {
ptrType.getCapabilities(capabilities, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getCapabilities");
}
@ -1000,3 +1015,89 @@ void StructType::getCapabilities(
for (Type elementType : getElementTypes())
elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
}
//===----------------------------------------------------------------------===//
// MatrixType
//===----------------------------------------------------------------------===//
struct spirv::detail::MatrixTypeStorage : public TypeStorage {
MatrixTypeStorage(Type columnType, uint32_t columnCount)
: TypeStorage(), columnType(columnType), columnCount(columnCount) {}
using KeyTy = std::tuple<Type, uint32_t>;
static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
// Initialize the memory using placement new.
return new (allocator.allocate<MatrixTypeStorage>())
MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
}
bool operator==(const KeyTy &key) const {
return key == KeyTy(columnType, columnCount);
}
Type columnType;
const uint32_t columnCount;
};
MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
return Base::get(columnType.getContext(), TypeKind::Matrix, columnType,
columnCount);
}
MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
Location location) {
return Base::getChecked(location, TypeKind::Matrix, columnType, columnCount);
}
LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
Type columnType,
uint32_t columnCount) {
if (columnCount < 2 || columnCount > 4)
return emitError(loc, "matrix can have 2, 3, or 4 columns only");
if (!isValidColumnType(columnType))
return emitError(loc, "matrix columns must be vectors of floats");
/// The underlying vectors (columns) must be of size 2, 3, or 4
ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
if (columnShape.size() != 1)
return emitError(loc, "matrix columns must be 1D vectors");
if (columnShape[0] < 2 || columnShape[0] > 4)
return emitError(loc, "matrix columns must be of size 2, 3, or 4");
return success();
}
/// Returns true if the matrix elements are vectors of float elements
bool MatrixType::isValidColumnType(Type columnType) {
if (auto vectorType = columnType.dyn_cast<VectorType>()) {
if (vectorType.getElementType().isa<FloatType>())
return true;
}
return false;
}
Type MatrixType::getElementType() const { return getImpl()->columnType; }
unsigned MatrixType::getNumElements() const { return getImpl()->columnCount; }
void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage) {
getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
}
void MatrixType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage) {
{
static const Capability caps[] = {Capability::Matrix};
ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
capabilities.push_back(ref);
}
// Add any capabilities associated with the underlying vectors (i.e., columns)
getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
}

View File

@ -225,6 +225,8 @@ private:
LogicalResult processStructType(ArrayRef<uint32_t> operands);
LogicalResult processMatrixType(ArrayRef<uint32_t> operands);
//===--------------------------------------------------------------------===//
// Constant
//===--------------------------------------------------------------------===//
@ -1170,6 +1172,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
return processRuntimeArrayType(operands);
case spirv::Opcode::OpTypeStruct:
return processStructType(operands);
case spirv::Opcode::OpTypeMatrix:
return processMatrixType(operands);
default:
return emitError(unknownLoc, "unhandled type instruction");
}
@ -1333,6 +1337,25 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
return success();
}
LogicalResult Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
if (operands.size() != 3) {
// Three operands are needed: result_id, column_type, and column_count
return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
" (result_id, column_type, and column_count)");
}
// Matrix columns must be of vector type
Type elementTy = getType(operands[1]);
if (!elementTy) {
return emitError(unknownLoc,
"OpTypeMatrix references undefined column type.")
<< operands[1];
}
uint32_t colsCount = operands[2];
typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
return success();
}
//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//
@ -2238,6 +2261,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
case spirv::Opcode::OpTypeInt:
case spirv::Opcode::OpTypeFloat:
case spirv::Opcode::OpTypeVector:
case spirv::Opcode::OpTypeMatrix:
case spirv::Opcode::OpTypeArray:
case spirv::Opcode::OpTypeFunction:
case spirv::Opcode::OpTypeRuntimeArray:

View File

@ -1111,6 +1111,17 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
return success();
}
if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
uint32_t elementTypeID = 0;
if (failed(processType(loc, matrixType.getElementType(), elementTypeID))) {
return failure();
}
typeEnum = spirv::Opcode::OpTypeMatrix;
operands.push_back(elementTypeID);
operands.push_back(matrixType.getNumElements());
return success();
}
// TODO(ravishankarm) : Handle other types.
return emitError(loc, "unhandled type in serialization: ") << type;
}

View File

@ -0,0 +1,22 @@
// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.func @matrix_type(%arg0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>, %arg1 : i32) "None" {
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
%2 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
spv.Return
}
}
// -----
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
spv.globalVariable @var0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
// CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.matrix<2 x vector<3xf32>>, StorageBuffer>
spv.globalVariable @var1 : !spv.ptr<!spv.matrix<2 x vector<3xf32>>, StorageBuffer>
// CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.matrix<4 x vector<4xf16>>, StorageBuffer>
spv.globalVariable @var2 : !spv.ptr<!spv.matrix<4 x vector<4xf16>>, StorageBuffer>
}

View File

@ -347,3 +347,87 @@ func @missing_scope(!spv.coopmatrix<8x16xi32>) -> ()
// expected-error @+1 {{expected rows and columns size}}
func @missing_count(!spv.coopmatrix<8xi32, Subgroup>) -> ()
// -----
//===----------------------------------------------------------------------===//
// Matrix
//===----------------------------------------------------------------------===//
// CHECK: func @matrix_type(!spv.matrix<2 x vector<2xf16>>)
func @matrix_type(!spv.matrix<2 x vector<2xf16>>) -> ()
// -----
// CHECK: func @matrix_type(!spv.matrix<3 x vector<3xf32>>)
func @matrix_type(!spv.matrix<3 x vector<3xf32>>) -> ()
// -----
// CHECK: func @matrix_type(!spv.matrix<4 x vector<4xf16>>)
func @matrix_type(!spv.matrix<4 x vector<4xf16>>) -> ()
// -----
// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
func @matrix_invalid_size(!spv.matrix<5 x vector<3xf32>>) -> ()
// -----
// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
func @matrix_invalid_size(!spv.matrix<1 x vector<3xf32>>) -> ()
// -----
// expected-error @+1 {{matrix columns size has to be less than or equal to 4 and greater than or equal 2, but found 5}}
func @matrix_invalid_columns_size(!spv.matrix<3 x vector<5xf32>>) -> ()
// -----
// expected-error @+1 {{matrix columns size has to be less than or equal to 4 and greater than or equal 2, but found 1}}
func @matrix_invalid_columns_size(!spv.matrix<3 x vector<1xf32>>) -> ()
// -----
// expected-error @+1 {{expected '<'}}
func @matrix_invalid_format(!spv.matrix 3 x vector<3xf32>>) -> ()
// -----
// expected-error @+1 {{unbalanced ')' character in pretty dialect name}}
func @matrix_invalid_format(!spv.matrix< 3 x vector<3xf32>) -> ()
// -----
// expected-error @+1 {{expected 'x' in dimension list}}
func @matrix_invalid_format(!spv.matrix<2 vector<3xi32>>) -> ()
// -----
// expected-error @+1 {{matrix must be composed using vector type, got 'i32'}}
func @matrix_invalid_type(!spv.matrix< 3 x i32>) -> ()
// -----
// expected-error @+1 {{matrix must be composed using vector type, got '!spv.array<16 x f32>'}}
func @matrix_invalid_type(!spv.matrix< 3 x !spv.array<16 x f32>>) -> ()
// -----
// expected-error @+1 {{matrix must be composed using vector type, got '!spv.rtarray<i32>'}}
func @matrix_invalid_type(!spv.matrix< 3 x !spv.rtarray<i32>>) -> ()
// -----
// expected-error @+1 {{matrix columns' elements must be of Float type, got 'i32'}}
func @matrix_invalid_type(!spv.matrix<2 x vector<3xi32>>) -> ()
// -----
// expected-error @+1 {{expected single unsigned integer for number of columns}}
func @matrix_size_type(!spv.matrix< x vector<3xi32>>) -> ()
// -----
// expected-error @+1 {{expected single unsigned integer for number of columns}}
func @matrix_size_type(!spv.matrix<2.0 x vector<3xi32>>) -> ()
// -----