[mlir][spirv] Define spv.VectorTimesScalar op

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D121247
This commit is contained in:
Lei Zhang 2022-03-08 15:58:31 -05:00
parent f8fb2aff70
commit cfb9e474ae
5 changed files with 106 additions and 26 deletions

View File

@ -564,6 +564,40 @@ def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv",
// -----
def SPV_VectorTimesScalarOp : SPV_Op<"VectorTimesScalar", [NoSideEffect]> {
let summary = "Scale a floating-point vector.";
let description = [{
Result Type must be a vector of floating-point type.
The type of Vector must be the same as Result Type. Each component of
Vector is multiplied by Scalar.
Scalar must have the same type as the Component Type in Result Type.
<!-- End of AutoGen section -->
#### Example:
```mlir
%0 = spv.VectorTimesScalar %vector, %scalar : vector<4xf32>
```
}];
let arguments = (ins
VectorOfLengthAndType<[2, 3, 4], [SPV_Float]>:$vector,
SPV_Float:$scalar
);
let results = (outs
VectorOfLengthAndType<[2, 3, 4], [SPV_Float]>:$result
);
let assemblyFormat = "operands attr-dict `:` `(` type(operands) `)` `->` type($result)";
}
// -----
def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod",
SPV_Integer,
[UnsignedOp, UsableInSpecConstantOp]> {

View File

@ -4078,6 +4078,7 @@ def SPV_OC_OpSRem : I32EnumAttrCase<"OpSRem", 138>;
def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>;
def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
def SPV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
def SPV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>;
@ -4202,32 +4203,33 @@ def SPV_OpcodeAttr :
SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered,
SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
SPV_OC_OpAtomicExchange, SPV_OC_OpAtomicCompareExchange,
SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,
SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast,
SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpVectorTimesScalar,
SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan,
SPV_OC_OpIsInf, SPV_OC_OpOrdered, SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual,
SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd,
SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual,
SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual,
SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan,
SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual,
SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual,
SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan,
SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual,
SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual,
SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical,
SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr,
SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot,
SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract,
SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier,
SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicExchange,
SPV_OC_OpAtomicCompareExchange, SPV_OC_OpAtomicCompareExchangeWeak,
SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd,
SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin,
SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd,
SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge,
SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast, SPV_OC_OpNoLine,
SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
SPV_OC_OpGroupNonUniformBroadcast, SPV_OC_OpGroupNonUniformBallot,
SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,

View File

@ -4424,6 +4424,19 @@ LogicalResult spirv::PtrAccessChainOp::verify() {
return verifyAccessChain(*this, indices());
}
//===----------------------------------------------------------------------===//
// spv.VectorTimesScalarOp
//===----------------------------------------------------------------------===//
LogicalResult spirv::VectorTimesScalarOp::verify() {
if (vector().getType() != getType())
return emitOpError("vector operand and result type mismatch");
auto scalarType = getType().cast<VectorType>().getElementType();
if (scalar().getType() != scalarType)
return emitOpError("scalar operand and result element type match");
return success();
}
// TableGen'erated operation interfaces for querying versions, extensions, and
// capabilities.
#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"

View File

@ -219,3 +219,29 @@ func @umod_scalar(%arg: i32) -> i32 {
return %0 : i32
}
// -----
//===----------------------------------------------------------------------===//
// spv.VectorTimesScalar
//===----------------------------------------------------------------------===//
func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<4xf32> {
// CHECK: spv.VectorTimesScalar %{{.+}}, %{{.+}} : (vector<4xf32>, f32) -> vector<4xf32>
%0 = spv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f32) -> vector<4xf32>
return %0 : vector<4xf32>
}
// -----
func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f16) -> vector<4xf32> {
// expected-error @+1 {{scalar operand and result element type match}}
%0 = spv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f16) -> vector<4xf32>
return %0 : vector<4xf32>
}
// -----
func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<3xf32> {
// expected-error @+1 {{vector operand and result type mismatch}}
%0 = spv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f32) -> vector<3xf32>
return %0 : vector<3xf32>
}

View File

@ -81,4 +81,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%0 = spv.SRem %arg0, %arg1 : vector<4xi32>
spv.Return
}
spv.func @vector_times_scalar(%arg0 : vector<4xf32>, %arg1 : f32) "None" {
// CHECK: {{%.*}} = spv.VectorTimesScalar {{%.*}}, {{%.*}} : (vector<4xf32>, f32) -> vector<4xf32>
%0 = spv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32>
spv.Return
}
}