diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 1fa72bf4dcab..83150dad514d 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3256,6 +3256,7 @@ def SPV_OC_OpGroupBroadcast : I32EnumAttrCase<"OpGroupBroadcast", 263 def SPV_OC_OpNoLine : I32EnumAttrCase<"OpNoLine", 317>; def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>; def SPV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>; +def SPV_OC_OpGroupNonUniformBroadcast : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>; def SPV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>; def SPV_OC_OpGroupNonUniformIAdd : I32EnumAttrCase<"OpGroupNonUniformIAdd", 349>; def SPV_OC_OpGroupNonUniformFAdd : I32EnumAttrCase<"OpGroupNonUniformFAdd", 350>; @@ -3323,16 +3324,16 @@ def SPV_OpcodeAttr : SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast, SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect, - SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd, - SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul, - SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin, - SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin, - SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax, - SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR, - SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV, - SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV, - SPV_OC_OpCooperativeMatrixLengthNV, SPV_OC_OpSubgroupBlockReadINTEL, - SPV_OC_OpSubgroupBlockWriteINTEL + SPV_OC_OpGroupNonUniformBroadcast, SPV_OC_OpGroupNonUniformBallot, + SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd, + SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul, + SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin, + SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax, + SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax, + SPV_OC_OpSubgroupBallotKHR, SPV_OC_OpTypeCooperativeMatrixNV, + SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV, + SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV, + SPV_OC_OpSubgroupBlockReadINTEL, SPV_OC_OpSubgroupBlockWriteINTEL ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td index 34be336bb2a5..da3da3050efc 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td @@ -105,6 +105,77 @@ def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> { // ----- +def SPV_GroupNonUniformBroadcastOp : SPV_Op<"GroupNonUniformBroadcast", + [NoSideEffect, AllTypesMatch<["value", "result"]>]> { + let summary = [{ + Return the Value of the invocation identified by the id Id to all active + invocations in the group. + }]; + + let description = [{ + Result Type must be a scalar or vector of floating-point type, integer + type, or Boolean type. + + Execution must be Workgroup or Subgroup Scope. + + The type of Value must be the same as Result Type. + + Id must be a scalar of integer type, whose Signedness operand is 0. + + Before version 1.5, Id must come from a constant instruction. Starting + with version 1.5, Id must be dynamically uniform. + + The resulting value is undefined if Id is an inactive invocation, or is + greater than or equal to the size of the group. + + + + ``` + scope ::= `"Workgroup"` | `"Subgroup"` + integer-float-scalar-vector-type ::= integer-type | float-type | + `vector<` integer-literal `x` integer-type `>` | + `vector<` integer-literal `x` float-type `>` + group-non-uniform-broadcast-op ::= ssa-id `=` + `spv.GroupNonUniformBroadcast` scope ssa_use, + ssa_use `:` integer-float-scalar-vector-type `,` integer-type + ```mlir + + #### Example: + + ``` + %scalar_value = ... : f32 + %vector_value = ... : vector<4xf32> + %id = ... : i32 + %0 = spv.GroupNonUniformBroadcast "Subgroup" %scalar_value, %id : f32, i32 + %1 = spv.GroupNonUniformBroadcast "Workgroup" %vector_value, %id : + vector<4xf32>, i32 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_GroupNonUniformBallot]> + ]; + + let arguments = (ins + SPV_ScopeAttr:$execution_scope, + SPV_Type:$value, + SPV_Integer:$id + ); + + let results = (outs + SPV_Type:$result + ); + + let assemblyFormat = [{ + $execution_scope operands attr-dict `:` type($value) `,` type($id) + }]; +} + +// ----- + def SPV_GroupNonUniformElectOp : SPV_Op<"GroupNonUniformElect", []> { let summary = [{ Result is true only in the active invocation with the lowest id in the @@ -368,8 +439,8 @@ def SPV_GroupNonUniformFMulOp : def SPV_GroupNonUniformIAddOp : SPV_GroupNonUniformArithmeticOp<"GroupNonUniformIAdd", SPV_Integer, []> { let summary = [{ - An integer add group operation of all Value operands contributed active - by invocations in the group. + An integer add group operation of all Value operands contributed by + active invocations in the group. }]; let description = [{ diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index a16dc1c8bc35..a01177132b27 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/SPIRV/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" +#include "mlir/Dialect/SPIRV/TargetAndABI.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/FunctionImplementation.h" @@ -2043,6 +2044,32 @@ static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) { return success(); } +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformBroadcast +//===----------------------------------------------------------------------===// + +static LogicalResult verify(spirv::GroupNonUniformBroadcastOp broadcastOp) { + spirv::Scope scope = broadcastOp.execution_scope(); + if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) + return broadcastOp.emitOpError( + "execution scope must be 'Workgroup' or 'Subgroup'"); + + // SPIR-V spec: "Before version 1.5, Id must come from a + // constant instruction. + auto targetEnv = spirv::getDefaultTargetEnv(broadcastOp.getContext()); + if (auto spirvModule = broadcastOp.getParentOfType()) + targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule); + + if (targetEnv.getVersion() < spirv::Version::V_1_5) { + auto *idOp = broadcastOp.id().getDefiningOp(); + if (!idOp || !isa(idOp)) // for spec constant + return broadcastOp.emitOpError("id must be the result of a constant op"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // spv.SubgroupBlockReadINTEL //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir index ab714dfbaa00..f7b8f6cfc185 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir @@ -8,6 +8,14 @@ spv.module Logical GLSL450 requires #spv.vce { spv.ReturnValue %0: vector<4xi32> } + // CHECK-LABEL: @group_non_uniform_broadcast + spv.func @group_non_uniform_broadcast(%value: f32) -> f32 "None" { + %one = spv.constant 1 : i32 + // CHECK: spv.GroupNonUniformBroadcast "Subgroup" %{{.*}}, %{{.*}} : f32, i32 + %0 = spv.GroupNonUniformBroadcast "Subgroup" %value, %one : f32, i32 + spv.ReturnValue %0: f32 + } + // CHECK-LABEL: @group_non_uniform_elect spv.func @group_non_uniform_elect() -> i1 "None" { // CHECK: %{{.+}} = spv.GroupNonUniformElect "Workgroup" : i1 diff --git a/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir index 86c3c2886a4f..5839ee7c5627 100644 --- a/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir +++ b/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir @@ -28,6 +28,45 @@ func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> { // ----- +//===----------------------------------------------------------------------===// +// spv.NonUniformGroupBroadcast +//===----------------------------------------------------------------------===// + +func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 { + %one = spv.constant 1 : i32 + // CHECK: spv.GroupNonUniformBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, i32 + %0 = spv.GroupNonUniformBroadcast "Workgroup" %value, %one : f32, i32 + return %0: f32 +} + +// ----- + +func @group_non_uniform_broadcast_vector(%value: vector<4xf32>) -> vector<4xf32> { + %one = spv.constant 1 : i32 + // CHECK: spv.GroupNonUniformBroadcast "Subgroup" %{{.*}}, %{{.*}} : vector<4xf32>, i32 + %0 = spv.GroupNonUniformBroadcast "Subgroup" %value, %one : vector<4xf32>, i32 + return %0: vector<4xf32> +} + +// ----- + +func @group_non_uniform_broadcast_negative_scope(%value: f32, %localid: i32 ) -> f32 { + %one = spv.constant 1 : i32 + // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}} + %0 = spv.GroupNonUniformBroadcast "Device" %value, %one : f32, i32 + return %0: f32 +} + +// ----- + +func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid: i32) -> f32 { + // expected-error @+1 {{id must be the result of a constant op}} + %0 = spv.GroupNonUniformBroadcast "Subgroup" %value, %localid : f32, i32 + return %0: f32 +} + +// ----- + //===----------------------------------------------------------------------===// // spv.GroupNonUniformElect //===----------------------------------------------------------------------===//