forked from OSchip/llvm-project
[mlir][spirv] Add MatrixTimesMatrix operation
Add MatrixTimesMatrix operation to SPIRV Dialect and add NoSideEffect trait to Matrix ops. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D82671
This commit is contained in:
parent
065fc1eafe
commit
34c4852015
|
@ -3166,6 +3166,7 @@ def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>;
|
|||
def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
|
||||
def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
|
||||
def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
|
||||
def SPV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
|
||||
def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>;
|
||||
def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>;
|
||||
def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>;
|
||||
|
@ -3273,38 +3274,38 @@ def SPV_OpcodeAttr :
|
|||
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
|
||||
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
|
||||
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
|
||||
SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
|
||||
SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
|
||||
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
|
||||
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
|
||||
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
|
||||
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
|
||||
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
|
||||
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
|
||||
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
|
||||
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
|
||||
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
|
||||
SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
|
||||
SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
|
||||
SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
|
||||
SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
|
||||
SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
|
||||
SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
|
||||
SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,
|
||||
SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
|
||||
SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
|
||||
SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
|
||||
SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine,
|
||||
SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
|
||||
SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd,
|
||||
SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul,
|
||||
SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
|
||||
SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin,
|
||||
SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax,
|
||||
SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
|
||||
SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
|
||||
SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
|
||||
SPV_OC_OpCooperativeMatrixLengthNV
|
||||
SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual,
|
||||
SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd,
|
||||
SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual,
|
||||
SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual,
|
||||
SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan,
|
||||
SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual,
|
||||
SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual,
|
||||
SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan,
|
||||
SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual,
|
||||
SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual,
|
||||
SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical,
|
||||
SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr,
|
||||
SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot,
|
||||
SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract,
|
||||
SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier,
|
||||
SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicCompareExchangeWeak,
|
||||
SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd,
|
||||
SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin,
|
||||
SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd,
|
||||
SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge,
|
||||
SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
|
||||
SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
|
||||
SPV_OC_OpUnreachable, SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed,
|
||||
SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot,
|
||||
SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
|
||||
SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,
|
||||
SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin,
|
||||
SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax,
|
||||
SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax,
|
||||
SPV_OC_OpSubgroupBallotKHR, SPV_OC_OpTypeCooperativeMatrixNV,
|
||||
SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV,
|
||||
SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV
|
||||
]>;
|
||||
|
||||
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
|
||||
|
|
|
@ -12,10 +12,65 @@
|
|||
|
||||
#ifndef SPIRV_MATRIX_OPS
|
||||
#define SPIRV_MATRIX_OPS
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
|
||||
def SPV_MatrixTimesMatrixOp : SPV_Op<"MatrixTimesMatrix", [NoSideEffect]> {
|
||||
let summary = "Linear-algebraic multiply of LeftMatrix X RightMatrix.";
|
||||
|
||||
let description = [{
|
||||
Result Type must be an OpTypeMatrix whose Column Type is a vector of
|
||||
floating-point type.
|
||||
|
||||
LeftMatrix must be a matrix whose Column Type is the same as the Column
|
||||
Type in Result Type.
|
||||
|
||||
RightMatrix must be a matrix with the same Component Type as the
|
||||
Component Type in Result Type. Its number of columns must equal the
|
||||
number of columns in Result Type. Its columns must have the same number
|
||||
of components as the number of columns in LeftMatrix.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
matrix-times-matrix-op ::= ssa-id `=` `spv.MatrixTimesMatrix` ssa-use,
|
||||
ssa-use `:` matrix-type `,` matrix-type `->` matrix-type
|
||||
```mlir
|
||||
|
||||
#### Example:
|
||||
|
||||
```
|
||||
%0 = spv.MatrixTimesMatrix %matrix_1, %matrix_2 :
|
||||
!spv.matrix<4 x vector<3xf32>>, !spv.matrix<3 x vector<4xf32>> ->
|
||||
!spv.matrix<4 x vector<4xf32>>
|
||||
```
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_5>,
|
||||
Extension<[]>,
|
||||
Capability<[SPV_C_Matrix]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_AnyMatrix:$leftmatrix,
|
||||
SPV_AnyMatrix:$rightmatrix
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_AnyMatrix:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
operands attr-dict `:` type($leftmatrix) `,` type($rightmatrix) `->` type($result)
|
||||
}];
|
||||
let verifier = [{ return verifyMatrixTimesMatrix(*this); }];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", [NoSideEffect]> {
|
||||
let summary = "Scale a floating-point matrix.";
|
||||
|
||||
let description = [{
|
||||
|
@ -79,7 +134,7 @@ def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_TransposeOp : SPV_Op<"Transpose", []> {
|
||||
def SPV_TransposeOp : SPV_Op<"Transpose", [NoSideEffect]> {
|
||||
let summary = "Transpose a matrix.";
|
||||
|
||||
let description = [{
|
||||
|
|
|
@ -410,13 +410,23 @@ public:
|
|||
Type columnType,
|
||||
uint32_t columnCount);
|
||||
|
||||
/// Returns true if the matrix elements are vectors of float elements
|
||||
/// Returns true if the matrix elements are vectors of float elements.
|
||||
static bool isValidColumnType(Type columnType);
|
||||
|
||||
Type getElementType() const;
|
||||
Type getColumnType() const;
|
||||
|
||||
/// Returns the number of rows.
|
||||
unsigned getNumRows() const;
|
||||
|
||||
/// Returns the number of columns.
|
||||
unsigned getNumColumns() const;
|
||||
|
||||
/// Returns total number of elements (rows*columns).
|
||||
unsigned getNumElements() const;
|
||||
|
||||
/// Returns the elements' type (i.e, single element type).
|
||||
Type getElementType() const;
|
||||
|
||||
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
|
||||
Optional<spirv::StorageClass> storage = llvm::None);
|
||||
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
|
||||
|
|
|
@ -723,7 +723,7 @@ static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
|
|||
}
|
||||
|
||||
static void print(MatrixType type, DialectAsmPrinter &os) {
|
||||
os << "matrix<" << type.getNumElements() << " x " << type.getElementType();
|
||||
os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
|
||||
os << ">";
|
||||
}
|
||||
|
||||
|
|
|
@ -2779,37 +2779,30 @@ static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
|
|||
// auto-generated verify method.
|
||||
|
||||
auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
|
||||
// Check that the scalar type is the same as the matrix components type.
|
||||
if (auto inputMatrixColumns =
|
||||
inputMatrix.getElementType().dyn_cast<VectorType>()) {
|
||||
if (op.scalar().getType() != inputMatrixColumns.getElementType())
|
||||
return op.emitError("input matrix components' type and scaling "
|
||||
"value must have the same type");
|
||||
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
|
||||
|
||||
// Note that the next three checks could be done using the AllTypesMatch
|
||||
// trait in the Op definition file but it generates a vague error message.
|
||||
// Check that the scalar type is the same as the matrix element type.
|
||||
if (op.scalar().getType() != inputMatrix.getElementType())
|
||||
return op.emitError("input matrix components' type and scaling value must "
|
||||
"have the same type");
|
||||
|
||||
// Check that the input and result matrices have the same size
|
||||
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
|
||||
if (inputMatrix.getNumElements() != resultMatrix.getNumElements())
|
||||
return op.emitError("input and result matrices must have "
|
||||
"the same number of columns");
|
||||
// Note that the next three checks could be done using the AllTypesMatch
|
||||
// trait in the Op definition file but it generates a vague error message.
|
||||
|
||||
if (auto resultMatrixColumns =
|
||||
resultMatrix.getElementType().dyn_cast<VectorType>()) {
|
||||
// Check that the input and result matrices' columns have the same type
|
||||
if (inputMatrixColumns.getElementType() !=
|
||||
resultMatrixColumns.getElementType())
|
||||
return op.emitError("input and result matrices' columns must "
|
||||
"have the same component type");
|
||||
// Check that the input and result matrices have the same columns' count
|
||||
if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
|
||||
return op.emitError("input and result matrices must have the same "
|
||||
"number of columns");
|
||||
|
||||
// Check that the input and result matrices' columns have the same size
|
||||
if (inputMatrixColumns.getNumElements() !=
|
||||
resultMatrixColumns.getNumElements())
|
||||
return op.emitError("input and result matrices' columns must "
|
||||
"have the same size");
|
||||
}
|
||||
}
|
||||
// Check that the input and result matrices' have the same rows count
|
||||
if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
|
||||
return op.emitError("input and result matrices' columns must have "
|
||||
"the same size");
|
||||
|
||||
// Check that the input and result matrices' have the same component type
|
||||
if (inputMatrix.getElementType() != resultMatrix.getElementType())
|
||||
return op.emitError("input and result matrices' columns must have "
|
||||
"the same component type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -2902,24 +2895,56 @@ static LogicalResult verifyTranspose(spirv::TransposeOp op) {
|
|||
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
|
||||
|
||||
// Verify that the input and output matrices have correct shapes.
|
||||
if (auto inputMatrixColumns =
|
||||
inputMatrix.getElementType().dyn_cast<VectorType>()) {
|
||||
if (inputMatrixColumns.getNumElements() != resultMatrix.getNumElements())
|
||||
return op.emitError("input matrix rows count must be equal to "
|
||||
"output matrix columns count");
|
||||
if (auto resultMatrixColumns =
|
||||
resultMatrix.getElementType().dyn_cast<VectorType>()) {
|
||||
if (resultMatrixColumns.getNumElements() != inputMatrix.getNumElements())
|
||||
return op.emitError("input matrix columns count must be equal "
|
||||
"to output matrix rows count");
|
||||
if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
|
||||
return op.emitError("input matrix rows count must be equal to "
|
||||
"output matrix columns count");
|
||||
|
||||
if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
|
||||
return op.emitError("input matrix columns count must be equal to "
|
||||
"output matrix rows count");
|
||||
|
||||
// Verify that the input and output matrices have the same component type
|
||||
if (inputMatrix.getElementType() != resultMatrix.getElementType())
|
||||
return op.emitError("input and output matrices must have the same "
|
||||
"component type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.MatrixTimesMatrix
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) {
|
||||
auto leftMatrix = op.leftmatrix().getType().cast<spirv::MatrixType>();
|
||||
auto rightMatrix = op.rightmatrix().getType().cast<spirv::MatrixType>();
|
||||
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
|
||||
|
||||
// left matrix columns' count and right matrix rows' count must be equal
|
||||
if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
|
||||
return op.emitError("left matrix columns' count must be equal to "
|
||||
"the right matrix rows' count");
|
||||
|
||||
// right and result matrices columns' count must be the same
|
||||
if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
|
||||
return op.emitError(
|
||||
"right and result matrices must have equal columns' count");
|
||||
|
||||
// right and result matrices component type must be the same
|
||||
if (rightMatrix.getElementType() != resultMatrix.getElementType())
|
||||
return op.emitError("right and result matrices' component type must"
|
||||
" be the same");
|
||||
|
||||
// left and result matrices component type must be the same
|
||||
if (leftMatrix.getElementType() != resultMatrix.getElementType())
|
||||
return op.emitError("left and result matrices' component type"
|
||||
" must be the same");
|
||||
|
||||
// left and result matrices rows count must be the same
|
||||
if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
|
||||
return op.emitError("left and result matrices must have equal rows'"
|
||||
" count");
|
||||
|
||||
// Verify that the input and output matrices have the same component type
|
||||
if (inputMatrixColumns.getElementType() !=
|
||||
resultMatrixColumns.getElementType())
|
||||
return op.emitError("input and output matrices must have the "
|
||||
"same component type");
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -182,7 +182,7 @@ Type CompositeType::getElementType(unsigned index) const {
|
|||
case spirv::TypeKind::CooperativeMatrix:
|
||||
return cast<CooperativeMatrixNVType>().getElementType();
|
||||
case spirv::TypeKind::Matrix:
|
||||
return cast<MatrixType>().getElementType();
|
||||
return cast<MatrixType>().getColumnType();
|
||||
case spirv::TypeKind::RuntimeArray:
|
||||
return cast<RuntimeArrayType>().getElementType();
|
||||
case spirv::TypeKind::Struct:
|
||||
|
@ -202,7 +202,7 @@ unsigned CompositeType::getNumElements() const {
|
|||
llvm_unreachable(
|
||||
"invalid to query number of elements of spirv::CooperativeMatrix type");
|
||||
case spirv::TypeKind::Matrix:
|
||||
return cast<MatrixType>().getNumElements();
|
||||
return cast<MatrixType>().getNumColumns();
|
||||
case spirv::TypeKind::RuntimeArray:
|
||||
llvm_unreachable(
|
||||
"invalid to query number of elements of spirv::RuntimeArray type");
|
||||
|
@ -1086,13 +1086,25 @@ bool MatrixType::isValidColumnType(Type columnType) {
|
|||
return false;
|
||||
}
|
||||
|
||||
Type MatrixType::getElementType() const { return getImpl()->columnType; }
|
||||
Type MatrixType::getColumnType() const { return getImpl()->columnType; }
|
||||
|
||||
unsigned MatrixType::getNumElements() const { return getImpl()->columnCount; }
|
||||
Type MatrixType::getElementType() const {
|
||||
return getImpl()->columnType.cast<VectorType>().getElementType();
|
||||
}
|
||||
|
||||
unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
|
||||
|
||||
unsigned MatrixType::getNumRows() const {
|
||||
return getImpl()->columnType.cast<VectorType>().getShape()[0];
|
||||
}
|
||||
|
||||
unsigned MatrixType::getNumElements() const {
|
||||
return (getImpl()->columnCount) * getNumRows();
|
||||
}
|
||||
|
||||
void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
|
||||
Optional<StorageClass> storage) {
|
||||
getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
|
||||
getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
|
||||
}
|
||||
|
||||
void MatrixType::getCapabilities(
|
||||
|
@ -1104,5 +1116,5 @@ void MatrixType::getCapabilities(
|
|||
capabilities.push_back(ref);
|
||||
}
|
||||
// Add any capabilities associated with the underlying vectors (i.e., columns)
|
||||
getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
|
||||
getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
|
||||
}
|
||||
|
|
|
@ -1127,12 +1127,12 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
|
|||
|
||||
if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
|
||||
uint32_t elementTypeID = 0;
|
||||
if (failed(processType(loc, matrixType.getElementType(), elementTypeID))) {
|
||||
if (failed(processType(loc, matrixType.getColumnType(), elementTypeID))) {
|
||||
return failure();
|
||||
}
|
||||
typeEnum = spirv::Opcode::OpTypeMatrix;
|
||||
operands.push_back(elementTypeID);
|
||||
operands.push_back(matrixType.getNumElements());
|
||||
operands.push_back(matrixType.getNumColumns());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -29,6 +29,20 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
|||
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
|
||||
spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @matrix_times_matrix_1
|
||||
spv.func @matrix_times_matrix_1(%arg0: !spv.matrix<3 x vector<3xf32>>, %arg1: !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None"{
|
||||
// CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
|
||||
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
|
||||
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @matrix_times_matrix_2
|
||||
spv.func @matrix_times_matrix_2(%arg0: !spv.matrix<3 x vector<2xf32>>, %arg1: !spv.matrix<2 x vector<3xf32>>) -> !spv.matrix<2 x vector<2xf32>> "None"{
|
||||
// CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
|
||||
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
|
||||
spv.ReturnValue %result : !spv.matrix<2 x vector<2xf32>>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -21,6 +21,20 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
|||
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
|
||||
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @matrix_times_matrix_1
|
||||
spv.func @matrix_times_matrix_1(%arg0: !spv.matrix<3 x vector<3xf32>>, %arg1: !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None"{
|
||||
// CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
|
||||
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
|
||||
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @matrix_times_matrix_2
|
||||
spv.func @matrix_times_matrix_2(%arg0: !spv.matrix<3 x vector<2xf32>>, %arg1: !spv.matrix<2 x vector<3xf32>>) -> !spv.matrix<2 x vector<2xf32>> "None"{
|
||||
// CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
|
||||
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
|
||||
spv.ReturnValue %result : !spv.matrix<2 x vector<2xf32>>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -74,3 +88,39 @@ func @transpose_op_type_mismatch(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
|
|||
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<4 x vector<3xf16>>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<3xf32>>){
|
||||
// expected-error @+1 {{right and result matrices must have equal columns' count}}
|
||||
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<3 x vector<2xf32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<3xf32>>){
|
||||
// expected-error @+1 {{left and result matrices must have equal rows' count}}
|
||||
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<3xf32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<2xf32>>){
|
||||
// expected-error @+1 {{left matrix columns' count must be equal to the right matrix rows' count}}
|
||||
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<2xf32>> -> !spv.matrix<2 x vector<2xf32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : !spv.matrix<3x vector<3xf32>>){
|
||||
// expected-error @+1 {{right and result matrices' component type must be the same}}
|
||||
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf64>>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spv.matrix<3 x vector<3xf64>>, %arg1 : !spv.matrix<3x vector<3xf32>>){
|
||||
// expected-error @+1 {{left and result matrices' component type must be the same}}
|
||||
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf64>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue