forked from OSchip/llvm-project
[mlir][spirv] Add MatrixTimesScalar operation
Summary: - Define the MatrixTimesScalar operation and add roundtrip tests. - Added a new base class for matrix-specific operations to avoid invalid operands type mismatch check. - Created a separate Matrix arithmetic operations td file to add more operations in the future. - Augmented the automatically generated verify method to print more fine-grained error messages. - Made minor Updates to the matrix type tests. Reviewers: antiagainst, rriddle, mravishankar Reviewed By: antiagainst Subscribers: mehdi_amini, jpienaar, shauheen, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, bader, grosul1, frgossen, Kayjukh, jurahul, msifontes Tags: #mlir Differential Revision: https://reviews.llvm.org/D81677
This commit is contained in:
parent
ac20150e29
commit
55d53d4f54
|
@ -2994,10 +2994,12 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
|
||||||
def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
|
def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
|
||||||
def SPV_IsCooperativeMatrixType :
|
def SPV_IsCooperativeMatrixType :
|
||||||
CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">;
|
CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">;
|
||||||
|
def SPV_IsMatrixType : CPred<"$_self.isa<::mlir::spirv::MatrixType>()">;
|
||||||
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
|
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
|
||||||
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
|
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
|
||||||
def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
|
def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
|
||||||
|
|
||||||
|
|
||||||
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
|
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
|
||||||
// for the definition of the following types and type categories.
|
// for the definition of the following types and type categories.
|
||||||
|
|
||||||
|
@ -3018,6 +3020,8 @@ def SPV_AnyArray : DialectType<SPIRV_Dialect, SPV_IsArrayType,
|
||||||
def SPV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
|
def SPV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
|
||||||
SPV_IsCooperativeMatrixType,
|
SPV_IsCooperativeMatrixType,
|
||||||
"any SPIR-V cooperative matrix type">;
|
"any SPIR-V cooperative matrix type">;
|
||||||
|
def SPV_AnyMatrix : DialectType<SPIRV_Dialect, SPV_IsMatrixType,
|
||||||
|
"any SPIR-V matrix type">;
|
||||||
def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
|
def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
|
||||||
"any SPIR-V runtime array type">;
|
"any SPIR-V runtime array type">;
|
||||||
def SPV_AnyStruct : DialectType<SPIRV_Dialect, SPV_IsStructType,
|
def SPV_AnyStruct : DialectType<SPIRV_Dialect, SPV_IsStructType,
|
||||||
|
@ -3028,11 +3032,11 @@ def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
|
||||||
def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
|
def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
|
||||||
def SPV_Composite :
|
def SPV_Composite :
|
||||||
AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
|
AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
|
||||||
SPV_AnyCooperativeMatrix]>;
|
SPV_AnyCooperativeMatrix, SPV_AnyMatrix]>;
|
||||||
def SPV_Type : AnyTypeOf<[
|
def SPV_Type : AnyTypeOf<[
|
||||||
SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
|
SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
|
||||||
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
|
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
|
||||||
SPV_AnyCooperativeMatrix
|
SPV_AnyCooperativeMatrix, SPV_AnyMatrix
|
||||||
]>;
|
]>;
|
||||||
|
|
||||||
def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
|
def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
|
||||||
|
@ -3160,6 +3164,7 @@ def SPV_OC_OpSRem : I32EnumAttrCase<"OpSRem", 138>;
|
||||||
def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>;
|
def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>;
|
||||||
def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
|
def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
|
||||||
def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
|
def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
|
||||||
|
def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
|
||||||
def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>;
|
def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>;
|
||||||
def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>;
|
def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>;
|
||||||
def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>;
|
def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>;
|
||||||
|
@ -3266,14 +3271,14 @@ def SPV_OpcodeAttr :
|
||||||
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
|
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
|
||||||
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
|
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
|
||||||
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
|
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
|
||||||
SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
|
SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
|
||||||
SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
|
SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
|
||||||
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
|
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
|
||||||
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
|
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
|
||||||
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
|
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
|
||||||
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
|
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
|
||||||
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
|
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
|
||||||
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
|
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
|
||||||
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
|
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
|
||||||
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
|
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
|
||||||
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
|
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
//===-- SPIRVMatrixOps.td - MLIR SPIR-V Matrix Ops ---------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file contains matrix operations for the SPIR-V dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef SPIRV_MATRIX_OPS
|
||||||
|
#define SPIRV_MATRIX_OPS
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
|
||||||
|
let summary = "Scale a floating-point matrix.";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Result Type must be an OpTypeMatrix whose Column Type is a vector of
|
||||||
|
floating-point type.
|
||||||
|
|
||||||
|
The type of Matrix must be the same as Result Type. Each component in
|
||||||
|
each column in Matrix is multiplied by Scalar.
|
||||||
|
|
||||||
|
Scalar must have the same type as the Component Type in Result Type.
|
||||||
|
|
||||||
|
<!-- End of AutoGen section -->
|
||||||
|
|
||||||
|
```
|
||||||
|
matrix-times-scalar-op ::= ssa-id `=` `spv.MatrixTimesScalar` ssa-use,
|
||||||
|
ssa-use `:` matrix-type `,` float-type `->` matrix-type
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
|
||||||
|
%0 = spv.MatrixTimesScalar %matrix, %scalar :
|
||||||
|
!spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
|
||||||
|
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SPV_AnyMatrix:$matrix,
|
||||||
|
SPV_Float:$scalar
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SPV_AnyMatrix:$result
|
||||||
|
);
|
||||||
|
|
||||||
|
// TODO (Hazem): we need just one matrix type given that the input and result
|
||||||
|
// are the same and the scalar's type can be deduced from it.
|
||||||
|
let assemblyFormat = [{
|
||||||
|
operands attr-dict `:` type($matrix) `,` type($scalar) `->` type($result)
|
||||||
|
}];
|
||||||
|
|
||||||
|
let availability = [
|
||||||
|
MinVersion<SPV_V_1_0>,
|
||||||
|
MaxVersion<SPV_V_1_5>,
|
||||||
|
Extension<[]>,
|
||||||
|
Capability<[SPV_C_Matrix]>
|
||||||
|
];
|
||||||
|
|
||||||
|
let verifier = [{ return verifyMatrixTimesScalar(*this); }];
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
#endif // SPIRV_MATRIX_OPS
|
|
@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td"
|
||||||
include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
|
include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
|
||||||
include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
|
include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
|
||||||
include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
|
include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
|
||||||
|
include "mlir/Dialect/SPIRV/SPIRVMatrixOps.td"
|
||||||
include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td"
|
include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td"
|
||||||
include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
|
include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
|
|
@ -2760,6 +2760,49 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.MatrixTimesScalar
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
|
||||||
|
// We already checked that result and matrix are both of matrix type in the
|
||||||
|
// auto-generated verify method.
|
||||||
|
|
||||||
|
auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
|
||||||
|
// Check that the scalar type is the same as the matrix components type.
|
||||||
|
if (auto inputMatrixColumns =
|
||||||
|
inputMatrix.getElementType().dyn_cast<VectorType>()) {
|
||||||
|
if (op.scalar().getType() != inputMatrixColumns.getElementType())
|
||||||
|
return op.emitError("input matrix components' type and scaling "
|
||||||
|
"value must have the same type");
|
||||||
|
|
||||||
|
// Note that the next three checks could be done using the AllTypesMatch
|
||||||
|
// trait in the Op definition file but it generates a vague error message.
|
||||||
|
|
||||||
|
// Check that the input and result matrices have the same size
|
||||||
|
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
|
||||||
|
if (inputMatrix.getNumElements() != resultMatrix.getNumElements())
|
||||||
|
return op.emitError("input and result matrices must have "
|
||||||
|
"the same number of columns");
|
||||||
|
|
||||||
|
if (auto resultMatrixColumns =
|
||||||
|
resultMatrix.getElementType().dyn_cast<VectorType>()) {
|
||||||
|
// Check that the input and result matrices' columns have the same type
|
||||||
|
if (inputMatrixColumns.getElementType() !=
|
||||||
|
resultMatrixColumns.getElementType())
|
||||||
|
return op.emitError("input and result matrices' columns must "
|
||||||
|
"have the same component type");
|
||||||
|
|
||||||
|
// Check that the input and result matrices' columns have the same size
|
||||||
|
if (inputMatrixColumns.getNumElements() !=
|
||||||
|
resultMatrixColumns.getNumElements())
|
||||||
|
return op.emitError("input and result matrices' columns must "
|
||||||
|
"have the same size");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace spirv {
|
namespace spirv {
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,25 @@
|
||||||
// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s
|
// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s
|
||||||
|
|
||||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||||
spv.func @matrix_type(%arg0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>, %arg1 : i32) "None" {
|
// CHECK-LABEL: @matrix_access_chain
|
||||||
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
|
spv.func @matrix_access_chain(%arg0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, Function>, %arg1 : i32) -> !spv.ptr<vector<3xf32>, Function> "None" {
|
||||||
%2 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
|
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, Function>
|
||||||
spv.Return
|
%0 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, Function>
|
||||||
|
spv.ReturnValue %0 : !spv.ptr<vector<3xf32>, Function>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @matrix_times_scalar_1
|
||||||
|
spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
|
||||||
|
// CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
|
||||||
|
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
|
||||||
|
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @matrix_times_scalar_2
|
||||||
|
spv.func @matrix_times_scalar_2(%arg0 : !spv.matrix<3 x vector<3xf16>>, %arg1 : f16) -> !spv.matrix<3 x vector<3xf16>> "None" {
|
||||||
|
// CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>>
|
||||||
|
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>>
|
||||||
|
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf16>>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||||
|
|
||||||
|
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||||
|
// CHECK-LABEL: @matrix_times_scalar
|
||||||
|
spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
|
||||||
|
// CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
|
||||||
|
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
|
||||||
|
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @input_type_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f16) -> () {
|
||||||
|
// expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
|
||||||
|
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f16 -> !spv.matrix<3 x vector<3xf32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @input_type_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f64) -> () {
|
||||||
|
// expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
|
||||||
|
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f64 -> !spv.matrix<3 x vector<3xf32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @input_output_component_type_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () {
|
||||||
|
// expected-error @+1 {{input and result matrices' columns must have the same component type}}
|
||||||
|
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf64>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @input_output_size_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () {
|
||||||
|
// expected-error @+1 {{input and result matrices must have the same number of columns}}
|
||||||
|
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<4 x vector<3xf32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue