[mlir][spirv] Add MatrixTimesScalar operation

Summary:
- Define the MatrixTimesScalar operation and add roundtrip tests.
- Added a new base class for matrix-specific operations to avoid invalid operands type mismatch check.
- Created a separate Matrix arithmetic operations td file to add more operations in the future.
- Augmented the automatically generated verify method to print more fine-grained error messages.
- Made minor Updates to the matrix type tests.

Reviewers: antiagainst, rriddle, mravishankar

Reviewed By: antiagainst

Subscribers: mehdi_amini, jpienaar, shauheen, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, bader, grosul1, frgossen, Kayjukh, jurahul, msifontes

Tags: #mlir

Differential Revision: https://reviews.llvm.org/D81677
This commit is contained in:
HazemAbdelhafez 2020-06-15 21:50:18 -04:00
parent ac20150e29
commit 55d53d4f54
6 changed files with 194 additions and 14 deletions

View File

@ -2994,10 +2994,12 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">; def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
def SPV_IsCooperativeMatrixType : def SPV_IsCooperativeMatrixType :
CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">; CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">;
def SPV_IsMatrixType : CPred<"$_self.isa<::mlir::spirv::MatrixType>()">;
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">; def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">; def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">; def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
// for the definition of the following types and type categories. // for the definition of the following types and type categories.
@ -3018,6 +3020,8 @@ def SPV_AnyArray : DialectType<SPIRV_Dialect, SPV_IsArrayType,
def SPV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect, def SPV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
SPV_IsCooperativeMatrixType, SPV_IsCooperativeMatrixType,
"any SPIR-V cooperative matrix type">; "any SPIR-V cooperative matrix type">;
def SPV_AnyMatrix : DialectType<SPIRV_Dialect, SPV_IsMatrixType,
"any SPIR-V matrix type">;
def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType, def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
"any SPIR-V runtime array type">; "any SPIR-V runtime array type">;
def SPV_AnyStruct : DialectType<SPIRV_Dialect, SPV_IsStructType, def SPV_AnyStruct : DialectType<SPIRV_Dialect, SPV_IsStructType,
@ -3028,11 +3032,11 @@ def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>; def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
def SPV_Composite : def SPV_Composite :
AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct, AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
SPV_AnyCooperativeMatrix]>; SPV_AnyCooperativeMatrix, SPV_AnyMatrix]>;
def SPV_Type : AnyTypeOf<[ def SPV_Type : AnyTypeOf<[
SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector, SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct, SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
SPV_AnyCooperativeMatrix SPV_AnyCooperativeMatrix, SPV_AnyMatrix
]>; ]>;
def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>; def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
@ -3160,6 +3164,7 @@ def SPV_OC_OpSRem : I32EnumAttrCase<"OpSRem", 138>;
def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>; def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>;
def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>; def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>; def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>;
def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>; def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>;
def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>; def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>;
@ -3266,14 +3271,14 @@ def SPV_OpcodeAttr :
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,

View File

@ -0,0 +1,75 @@
//===-- SPIRVMatrixOps.td - MLIR SPIR-V Matrix Ops ---------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains matrix operations for the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#ifndef SPIRV_MATRIX_OPS
#define SPIRV_MATRIX_OPS
// -----
def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
let summary = "Scale a floating-point matrix.";
let description = [{
Result Type must be an OpTypeMatrix whose Column Type is a vector of
floating-point type.
The type of Matrix must be the same as Result Type. Each component in
each column in Matrix is multiplied by Scalar.
Scalar must have the same type as the Component Type in Result Type.
<!-- End of AutoGen section -->
```
matrix-times-scalar-op ::= ssa-id `=` `spv.MatrixTimesScalar` ssa-use,
ssa-use `:` matrix-type `,` float-type `->` matrix-type
```
#### Example:
```mlir
%0 = spv.MatrixTimesScalar %matrix, %scalar :
!spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
```
}];
let arguments = (ins
SPV_AnyMatrix:$matrix,
SPV_Float:$scalar
);
let results = (outs
SPV_AnyMatrix:$result
);
// TODO (Hazem): we need just one matrix type given that the input and result
// are the same and the scalar's type can be deduced from it.
let assemblyFormat = [{
operands attr-dict `:` type($matrix) `,` type($scalar) `->` type($result)
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_Matrix]>
];
let verifier = [{ return verifyMatrixTimesScalar(*this); }];
}
// -----
#endif // SPIRV_MATRIX_OPS

View File

@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td"
include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td" include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
include "mlir/Dialect/SPIRV/SPIRVGroupOps.td" include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td" include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
include "mlir/Dialect/SPIRV/SPIRVMatrixOps.td"
include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td" include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td"
include "mlir/Dialect/SPIRV/SPIRVStructureOps.td" include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"

View File

@ -2760,6 +2760,49 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// spv.MatrixTimesScalar
//===----------------------------------------------------------------------===//
static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
// We already checked that result and matrix are both of matrix type in the
// auto-generated verify method.
auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
// Check that the scalar type is the same as the matrix components type.
if (auto inputMatrixColumns =
inputMatrix.getElementType().dyn_cast<VectorType>()) {
if (op.scalar().getType() != inputMatrixColumns.getElementType())
return op.emitError("input matrix components' type and scaling "
"value must have the same type");
// Note that the next three checks could be done using the AllTypesMatch
// trait in the Op definition file but it generates a vague error message.
// Check that the input and result matrices have the same size
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
if (inputMatrix.getNumElements() != resultMatrix.getNumElements())
return op.emitError("input and result matrices must have "
"the same number of columns");
if (auto resultMatrixColumns =
resultMatrix.getElementType().dyn_cast<VectorType>()) {
// Check that the input and result matrices' columns have the same type
if (inputMatrixColumns.getElementType() !=
resultMatrixColumns.getElementType())
return op.emitError("input and result matrices' columns must "
"have the same component type");
// Check that the input and result matrices' columns have the same size
if (inputMatrixColumns.getNumElements() !=
resultMatrixColumns.getNumElements())
return op.emitError("input and result matrices' columns must "
"have the same size");
}
}
return success();
}
namespace mlir { namespace mlir {
namespace spirv { namespace spirv {

View File

@ -1,10 +1,25 @@
// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s // RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> { spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.func @matrix_type(%arg0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>, %arg1 : i32) "None" { // CHECK-LABEL: @matrix_access_chain
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer> spv.func @matrix_access_chain(%arg0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, Function>, %arg1 : i32) -> !spv.ptr<vector<3xf32>, Function> "None" {
%2 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer> // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, Function>
spv.Return %0 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, Function>
spv.ReturnValue %0 : !spv.ptr<vector<3xf32>, Function>
}
// CHECK-LABEL: @matrix_times_scalar_1
spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
// CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
}
// CHECK-LABEL: @matrix_times_scalar_2
spv.func @matrix_times_scalar_2(%arg0 : !spv.matrix<3 x vector<3xf16>>, %arg1 : f16) -> !spv.matrix<3 x vector<3xf16>> "None" {
// CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>>
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>>
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf16>>
} }
} }

View File

@ -0,0 +1,41 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK-LABEL: @matrix_times_scalar
spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
// CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
}
}
// -----
func @input_type_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f16) -> () {
// expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f16 -> !spv.matrix<3 x vector<3xf32>>
}
// -----
func @input_type_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f64) -> () {
// expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f64 -> !spv.matrix<3 x vector<3xf32>>
}
// -----
func @input_output_component_type_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () {
// expected-error @+1 {{input and result matrices' columns must have the same component type}}
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf64>>
}
// -----
func @input_output_size_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () {
// expected-error @+1 {{input and result matrices must have the same number of columns}}
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<4 x vector<3xf32>>
}