forked from OSchip/llvm-project
[mlir][spirv] Add OpGroupBroadcast
OpGroupBroadcast added to SPIRV dialect Differential Revision: https://reviews.llvm.org/D85435
This commit is contained in:
parent
4061d9e42c
commit
a8fe40d973
|
@ -3231,6 +3231,7 @@ def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional",
|
|||
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
|
||||
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
|
||||
def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>;
|
||||
def SPV_OC_OpGroupBroadcast : I32EnumAttrCase<"OpGroupBroadcast", 263>;
|
||||
def SPV_OC_OpNoLine : I32EnumAttrCase<"OpNoLine", 317>;
|
||||
def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
|
||||
def SPV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
|
||||
|
@ -3297,8 +3298,8 @@ def SPV_OpcodeAttr :
|
|||
SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
|
||||
SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
|
||||
SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
|
||||
SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine,
|
||||
SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
|
||||
SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast,
|
||||
SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
|
||||
SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd,
|
||||
SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul,
|
||||
SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
|
||||
|
|
|
@ -17,6 +17,82 @@
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
|
||||
[NoSideEffect, AllTypesMatch<["value", "result"]>]> {
|
||||
let summary = [{
|
||||
Return the Value of the invocation identified by the local id LocalId to
|
||||
all invocations in the group.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
All invocations of this module within Execution must reach this point of
|
||||
execution.
|
||||
|
||||
Behavior is undefined if this instruction is used in control flow that
|
||||
is non-uniform within Execution.
|
||||
|
||||
Result Type must be a scalar or vector of floating-point type, integer
|
||||
type, or Boolean type.
|
||||
|
||||
Execution must be Workgroup or Subgroup Scope.
|
||||
|
||||
The type of Value must be the same as Result Type.
|
||||
|
||||
LocalId must be an integer datatype. It can be a scalar, or a vector
|
||||
with 2 components or a vector with 3 components. LocalId must be the
|
||||
same for all invocations in the group.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
scope ::= `"Workgroup"` | `"Subgroup"`
|
||||
integer-float-scalar-vector-type ::= integer-type | float-type |
|
||||
`vector<` integer-literal `x` integer-type `>` |
|
||||
`vector<` integer-literal `x` float-type `>`
|
||||
localid-type ::= integer-type |
|
||||
`vector<` integer-literal `x` integer-type `>`
|
||||
group-broadcast-op ::= ssa-id `=` `spv.GroupBroadcast` scope ssa_use,
|
||||
ssa_use `:` integer-float-scalar-vector-type `,` localid-type
|
||||
```mlir
|
||||
|
||||
#### Example:
|
||||
|
||||
```
|
||||
%scalar_value = ... : f32
|
||||
%vector_value = ... : vector<4xf32>
|
||||
%scalar_localid = ... : i32
|
||||
%vector_localid = ... : vector<3xi32>
|
||||
%0 = spv.GroupBroadcast "Subgroup" %scalar_value, %scalar_localid : f32, i32
|
||||
%1 = spv.GroupBroadcast "Workgroup" %vector_value, %vector_localid :
|
||||
vector<4xf32>, vector<3xi32>
|
||||
```
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_5>,
|
||||
Extension<[]>,
|
||||
Capability<[SPV_C_Groups]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_ScopeAttr:$execution_scope,
|
||||
SPV_Type:$value,
|
||||
SPV_ScalarOrVectorOf<SPV_Integer>:$localid
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_Type:$result
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$execution_scope operands attr-dict `:` type($value) `,` type($localid)
|
||||
}];
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
|
||||
let summary = "See extension SPV_KHR_shader_ballot";
|
||||
|
||||
|
|
|
@ -1993,6 +1993,25 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.GroupBroadcast
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(spirv::GroupBroadcastOp broadcastOp) {
|
||||
spirv::Scope scope = broadcastOp.execution_scope();
|
||||
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
|
||||
return broadcastOp.emitOpError(
|
||||
"execution scope must be 'Workgroup' or 'Subgroup'");
|
||||
|
||||
if (auto localIdTy = broadcastOp.localid().getType().dyn_cast<VectorType>())
|
||||
if (!(localIdTy.getNumElements() == 2 || localIdTy.getNumElements() == 3))
|
||||
return broadcastOp.emitOpError("localid is a vector and can be with only "
|
||||
" 2 or 3 components, actual number is ")
|
||||
<< localIdTy.getNumElements();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.GroupNonUniformBallotOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -7,4 +7,16 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
|||
%0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
|
||||
spv.ReturnValue %0: vector<4xi32>
|
||||
}
|
||||
// CHECK-LABEL: @group_broadcast_1
|
||||
spv.func @group_broadcast_1(%value: f32, %localid: i32 ) -> f32 "None" {
|
||||
// CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, i32
|
||||
%0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, i32
|
||||
spv.ReturnValue %0: f32
|
||||
}
|
||||
// CHECK-LABEL: @group_broadcast_2
|
||||
spv.func @group_broadcast_2(%value: f32, %localid: vector<3xi32> ) -> f32 "None" {
|
||||
// CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, vector<3xi32>
|
||||
%0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
|
||||
spv.ReturnValue %0: f32
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,3 +9,55 @@ func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
|
|||
%0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
|
||||
return %0: vector<4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.GroupBroadcast
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @group_broadcast_scalar(%value: f32, %localid: i32 ) -> f32 {
|
||||
// CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, i32
|
||||
%0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, i32
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @group_broadcast_scalar_vector(%value: f32, %localid: vector<3xi32> ) -> f32 {
|
||||
// CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, vector<3xi32>
|
||||
%0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @group_broadcast_vector(%value: vector<4xf32>, %localid: vector<3xi32> ) -> vector<4xf32> {
|
||||
// CHECK: spv.GroupBroadcast "Subgroup" %{{.*}}, %{{.*}} : vector<4xf32>, vector<3xi32>
|
||||
%0 = spv.GroupBroadcast "Subgroup" %value, %localid : vector<4xf32>, vector<3xi32>
|
||||
return %0: vector<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> ) -> f32 {
|
||||
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
|
||||
%0 = spv.GroupBroadcast "Device" %value, %localid : f32, vector<3xi32>
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3xf32> ) -> f32 {
|
||||
// expected-error @+1 {{operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
|
||||
%0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<3xf32>
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @group_broadcast_negative_locid_vec4(%value: f32, %localid: vector<4xi32> ) -> f32 {
|
||||
// expected-error @+1 {{localid is a vector and can be with only 2 or 3 components, actual number is 4}}
|
||||
%0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<4xi32>
|
||||
return %0: f32
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue