[mlir][spirv] Add MatrixTimesMatrix operation

Add MatrixTimesMatrix operation to SPIRV Dialect and add NoSideEffect trait
to Matrix ops.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D82671
This commit is contained in:
HazemAbdelhafez 2020-07-07 21:32:39 -04:00 committed by Lei Zhang
parent 065fc1eafe
commit 34c4852015
9 changed files with 256 additions and 89 deletions

View File

@ -3166,6 +3166,7 @@ def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>;
def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
def SPV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>;
def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>;
def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>;
@ -3273,38 +3274,38 @@ def SPV_OpcodeAttr :
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,
SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine,
SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd,
SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul,
SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin,
SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax,
SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
SPV_OC_OpCooperativeMatrixLengthNV
SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual,
SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd,
SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual,
SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual,
SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan,
SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual,
SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual,
SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan,
SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual,
SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual,
SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical,
SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr,
SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot,
SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract,
SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier,
SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicCompareExchangeWeak,
SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd,
SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin,
SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd,
SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge,
SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
SPV_OC_OpUnreachable, SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed,
SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot,
SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,
SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin,
SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax,
SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax,
SPV_OC_OpSubgroupBallotKHR, SPV_OC_OpTypeCooperativeMatrixNV,
SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV,
SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

View File

@ -12,10 +12,65 @@
#ifndef SPIRV_MATRIX_OPS
#define SPIRV_MATRIX_OPS
include "mlir/Interfaces/SideEffectInterfaces.td"
// -----
def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
def SPV_MatrixTimesMatrixOp : SPV_Op<"MatrixTimesMatrix", [NoSideEffect]> {
let summary = "Linear-algebraic multiply of LeftMatrix X RightMatrix.";
let description = [{
Result Type must be an OpTypeMatrix whose Column Type is a vector of
floating-point type.
LeftMatrix must be a matrix whose Column Type is the same as the Column
Type in Result Type.
RightMatrix must be a matrix with the same Component Type as the
Component Type in Result Type. Its number of columns must equal the
number of columns in Result Type. Its columns must have the same number
of components as the number of columns in LeftMatrix.
<!-- End of AutoGen section -->
```
matrix-times-matrix-op ::= ssa-id `=` `spv.MatrixTimesMatrix` ssa-use,
ssa-use `:` matrix-type `,` matrix-type `->` matrix-type
```mlir
#### Example:
```
%0 = spv.MatrixTimesMatrix %matrix_1, %matrix_2 :
!spv.matrix<4 x vector<3xf32>>, !spv.matrix<3 x vector<4xf32>> ->
!spv.matrix<4 x vector<4xf32>>
```
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_Matrix]>
];
let arguments = (ins
SPV_AnyMatrix:$leftmatrix,
SPV_AnyMatrix:$rightmatrix
);
let results = (outs
SPV_AnyMatrix:$result
);
let assemblyFormat = [{
operands attr-dict `:` type($leftmatrix) `,` type($rightmatrix) `->` type($result)
}];
let verifier = [{ return verifyMatrixTimesMatrix(*this); }];
}
// -----
def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", [NoSideEffect]> {
let summary = "Scale a floating-point matrix.";
let description = [{
@ -79,7 +134,7 @@ def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
// -----
def SPV_TransposeOp : SPV_Op<"Transpose", []> {
def SPV_TransposeOp : SPV_Op<"Transpose", [NoSideEffect]> {
let summary = "Transpose a matrix.";
let description = [{

View File

@ -410,13 +410,23 @@ public:
Type columnType,
uint32_t columnCount);
/// Returns true if the matrix elements are vectors of float elements
/// Returns true if the matrix elements are vectors of float elements.
static bool isValidColumnType(Type columnType);
Type getElementType() const;
Type getColumnType() const;
/// Returns the number of rows.
unsigned getNumRows() const;
/// Returns the number of columns.
unsigned getNumColumns() const;
/// Returns total number of elements (rows*columns).
unsigned getNumElements() const;
/// Returns the elements' type (i.e, single element type).
Type getElementType() const;
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<spirv::StorageClass> storage = llvm::None);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,

View File

@ -723,7 +723,7 @@ static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
}
static void print(MatrixType type, DialectAsmPrinter &os) {
os << "matrix<" << type.getNumElements() << " x " << type.getElementType();
os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
os << ">";
}

View File

@ -2779,37 +2779,30 @@ static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
// auto-generated verify method.
auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
// Check that the scalar type is the same as the matrix components type.
if (auto inputMatrixColumns =
inputMatrix.getElementType().dyn_cast<VectorType>()) {
if (op.scalar().getType() != inputMatrixColumns.getElementType())
return op.emitError("input matrix components' type and scaling "
"value must have the same type");
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
// Note that the next three checks could be done using the AllTypesMatch
// trait in the Op definition file but it generates a vague error message.
// Check that the scalar type is the same as the matrix element type.
if (op.scalar().getType() != inputMatrix.getElementType())
return op.emitError("input matrix components' type and scaling value must "
"have the same type");
// Check that the input and result matrices have the same size
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
if (inputMatrix.getNumElements() != resultMatrix.getNumElements())
return op.emitError("input and result matrices must have "
"the same number of columns");
// Note that the next three checks could be done using the AllTypesMatch
// trait in the Op definition file but it generates a vague error message.
if (auto resultMatrixColumns =
resultMatrix.getElementType().dyn_cast<VectorType>()) {
// Check that the input and result matrices' columns have the same type
if (inputMatrixColumns.getElementType() !=
resultMatrixColumns.getElementType())
return op.emitError("input and result matrices' columns must "
"have the same component type");
// Check that the input and result matrices have the same columns' count
if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
return op.emitError("input and result matrices must have the same "
"number of columns");
// Check that the input and result matrices' columns have the same size
if (inputMatrixColumns.getNumElements() !=
resultMatrixColumns.getNumElements())
return op.emitError("input and result matrices' columns must "
"have the same size");
}
}
// Check that the input and result matrices' have the same rows count
if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
return op.emitError("input and result matrices' columns must have "
"the same size");
// Check that the input and result matrices' have the same component type
if (inputMatrix.getElementType() != resultMatrix.getElementType())
return op.emitError("input and result matrices' columns must have "
"the same component type");
return success();
}
@ -2902,24 +2895,56 @@ static LogicalResult verifyTranspose(spirv::TransposeOp op) {
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
// Verify that the input and output matrices have correct shapes.
if (auto inputMatrixColumns =
inputMatrix.getElementType().dyn_cast<VectorType>()) {
if (inputMatrixColumns.getNumElements() != resultMatrix.getNumElements())
return op.emitError("input matrix rows count must be equal to "
"output matrix columns count");
if (auto resultMatrixColumns =
resultMatrix.getElementType().dyn_cast<VectorType>()) {
if (resultMatrixColumns.getNumElements() != inputMatrix.getNumElements())
return op.emitError("input matrix columns count must be equal "
"to output matrix rows count");
if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
return op.emitError("input matrix rows count must be equal to "
"output matrix columns count");
if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
return op.emitError("input matrix columns count must be equal to "
"output matrix rows count");
// Verify that the input and output matrices have the same component type
if (inputMatrix.getElementType() != resultMatrix.getElementType())
return op.emitError("input and output matrices must have the same "
"component type");
return success();
}
//===----------------------------------------------------------------------===//
// spv.MatrixTimesMatrix
//===----------------------------------------------------------------------===//
static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) {
auto leftMatrix = op.leftmatrix().getType().cast<spirv::MatrixType>();
auto rightMatrix = op.rightmatrix().getType().cast<spirv::MatrixType>();
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
// left matrix columns' count and right matrix rows' count must be equal
if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
return op.emitError("left matrix columns' count must be equal to "
"the right matrix rows' count");
// right and result matrices columns' count must be the same
if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
return op.emitError(
"right and result matrices must have equal columns' count");
// right and result matrices component type must be the same
if (rightMatrix.getElementType() != resultMatrix.getElementType())
return op.emitError("right and result matrices' component type must"
" be the same");
// left and result matrices component type must be the same
if (leftMatrix.getElementType() != resultMatrix.getElementType())
return op.emitError("left and result matrices' component type"
" must be the same");
// left and result matrices rows count must be the same
if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
return op.emitError("left and result matrices must have equal rows'"
" count");
// Verify that the input and output matrices have the same component type
if (inputMatrixColumns.getElementType() !=
resultMatrixColumns.getElementType())
return op.emitError("input and output matrices must have the "
"same component type");
}
}
return success();
}

View File

@ -182,7 +182,7 @@ Type CompositeType::getElementType(unsigned index) const {
case spirv::TypeKind::CooperativeMatrix:
return cast<CooperativeMatrixNVType>().getElementType();
case spirv::TypeKind::Matrix:
return cast<MatrixType>().getElementType();
return cast<MatrixType>().getColumnType();
case spirv::TypeKind::RuntimeArray:
return cast<RuntimeArrayType>().getElementType();
case spirv::TypeKind::Struct:
@ -202,7 +202,7 @@ unsigned CompositeType::getNumElements() const {
llvm_unreachable(
"invalid to query number of elements of spirv::CooperativeMatrix type");
case spirv::TypeKind::Matrix:
return cast<MatrixType>().getNumElements();
return cast<MatrixType>().getNumColumns();
case spirv::TypeKind::RuntimeArray:
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
@ -1086,13 +1086,25 @@ bool MatrixType::isValidColumnType(Type columnType) {
return false;
}
Type MatrixType::getElementType() const { return getImpl()->columnType; }
Type MatrixType::getColumnType() const { return getImpl()->columnType; }
unsigned MatrixType::getNumElements() const { return getImpl()->columnCount; }
Type MatrixType::getElementType() const {
return getImpl()->columnType.cast<VectorType>().getElementType();
}
unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
unsigned MatrixType::getNumRows() const {
return getImpl()->columnType.cast<VectorType>().getShape()[0];
}
unsigned MatrixType::getNumElements() const {
return (getImpl()->columnCount) * getNumRows();
}
void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage) {
getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
}
void MatrixType::getCapabilities(
@ -1104,5 +1116,5 @@ void MatrixType::getCapabilities(
capabilities.push_back(ref);
}
// Add any capabilities associated with the underlying vectors (i.e., columns)
getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
}

View File

@ -1127,12 +1127,12 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
uint32_t elementTypeID = 0;
if (failed(processType(loc, matrixType.getElementType(), elementTypeID))) {
if (failed(processType(loc, matrixType.getColumnType(), elementTypeID))) {
return failure();
}
typeEnum = spirv::Opcode::OpTypeMatrix;
operands.push_back(elementTypeID);
operands.push_back(matrixType.getNumElements());
operands.push_back(matrixType.getNumColumns());
return success();
}

View File

@ -29,6 +29,20 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
}
// CHECK-LABEL: @matrix_times_matrix_1
spv.func @matrix_times_matrix_1(%arg0: !spv.matrix<3 x vector<3xf32>>, %arg1: !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None"{
// CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
}
// CHECK-LABEL: @matrix_times_matrix_2
spv.func @matrix_times_matrix_2(%arg0: !spv.matrix<3 x vector<2xf32>>, %arg1: !spv.matrix<2 x vector<3xf32>>) -> !spv.matrix<2 x vector<2xf32>> "None"{
// CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
spv.ReturnValue %result : !spv.matrix<2 x vector<2xf32>>
}
}
// -----

View File

@ -21,6 +21,20 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
}
// CHECK-LABEL: @matrix_times_matrix_1
spv.func @matrix_times_matrix_1(%arg0: !spv.matrix<3 x vector<3xf32>>, %arg1: !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None"{
// CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
}
// CHECK-LABEL: @matrix_times_matrix_2
spv.func @matrix_times_matrix_2(%arg0: !spv.matrix<3 x vector<2xf32>>, %arg1: !spv.matrix<2 x vector<3xf32>>) -> !spv.matrix<2 x vector<2xf32>> "None"{
// CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
spv.ReturnValue %result : !spv.matrix<2 x vector<2xf32>>
}
}
// -----
@ -74,3 +88,39 @@ func @transpose_op_type_mismatch(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<4 x vector<3xf16>>
spv.Return
}
// -----
func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<3xf32>>){
// expected-error @+1 {{right and result matrices must have equal columns' count}}
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<3 x vector<2xf32>>
}
// -----
func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<3xf32>>){
// expected-error @+1 {{left and result matrices must have equal rows' count}}
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<3xf32>>
}
// -----
func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<2xf32>>){
// expected-error @+1 {{left matrix columns' count must be equal to the right matrix rows' count}}
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<2xf32>> -> !spv.matrix<2 x vector<2xf32>>
}
// -----
func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : !spv.matrix<3x vector<3xf32>>){
// expected-error @+1 {{right and result matrices' component type must be the same}}
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf64>>
}
// -----
func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spv.matrix<3 x vector<3xf64>>, %arg1 : !spv.matrix<3x vector<3xf32>>){
// expected-error @+1 {{left and result matrices' component type must be the same}}
%result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf64>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
}