[mlir][spirv] Add OpGroupBroadcast

OpGroupBroadcast added to SPIRV dialect

Differential Revision: https://reviews.llvm.org/D85435
This commit is contained in:
Artur Bialas 2020-08-10 09:39:27 -07:00 committed by Thomas Raoux
parent 4061d9e42c
commit a8fe40d973
5 changed files with 162 additions and 2 deletions

View File

@ -3231,6 +3231,7 @@ def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional",
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>;
def SPV_OC_OpGroupBroadcast : I32EnumAttrCase<"OpGroupBroadcast", 263>;
def SPV_OC_OpNoLine : I32EnumAttrCase<"OpNoLine", 317>;
def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
def SPV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
@ -3297,8 +3298,8 @@ def SPV_OpcodeAttr :
SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine,
SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast,
SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd,
SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul,
SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,

View File

@ -17,6 +17,82 @@
// -----
def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
[NoSideEffect, AllTypesMatch<["value", "result"]>]> {
let summary = [{
Return the Value of the invocation identified by the local id LocalId to
all invocations in the group.
}];
let description = [{
All invocations of this module within Execution must reach this point of
execution.
Behavior is undefined if this instruction is used in control flow that
is non-uniform within Execution.
Result Type must be a scalar or vector of floating-point type, integer
type, or Boolean type.
Execution must be Workgroup or Subgroup Scope.
The type of Value must be the same as Result Type.
LocalId must be an integer datatype. It can be a scalar, or a vector
with 2 components or a vector with 3 components. LocalId must be the
same for all invocations in the group.
<!-- End of AutoGen section -->
```
scope ::= `"Workgroup"` | `"Subgroup"`
integer-float-scalar-vector-type ::= integer-type | float-type |
`vector<` integer-literal `x` integer-type `>` |
`vector<` integer-literal `x` float-type `>`
localid-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
group-broadcast-op ::= ssa-id `=` `spv.GroupBroadcast` scope ssa_use,
ssa_use `:` integer-float-scalar-vector-type `,` localid-type
```mlir
#### Example:
```
%scalar_value = ... : f32
%vector_value = ... : vector<4xf32>
%scalar_localid = ... : i32
%vector_localid = ... : vector<3xi32>
%0 = spv.GroupBroadcast "Subgroup" %scalar_value, %scalar_localid : f32, i32
%1 = spv.GroupBroadcast "Workgroup" %vector_value, %vector_localid :
vector<4xf32>, vector<3xi32>
```
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_Groups]>
];
let arguments = (ins
SPV_ScopeAttr:$execution_scope,
SPV_Type:$value,
SPV_ScalarOrVectorOf<SPV_Integer>:$localid
);
let results = (outs
SPV_Type:$result
);
let assemblyFormat = [{
$execution_scope operands attr-dict `:` type($value) `,` type($localid)
}];
}
// -----
def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
let summary = "See extension SPV_KHR_shader_ballot";

View File

@ -1993,6 +1993,25 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
return success();
}
//===----------------------------------------------------------------------===//
// spv.GroupBroadcast
//===----------------------------------------------------------------------===//
static LogicalResult verify(spirv::GroupBroadcastOp broadcastOp) {
spirv::Scope scope = broadcastOp.execution_scope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return broadcastOp.emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
if (auto localIdTy = broadcastOp.localid().getType().dyn_cast<VectorType>())
if (!(localIdTy.getNumElements() == 2 || localIdTy.getNumElements() == 3))
return broadcastOp.emitOpError("localid is a vector and can be with only "
" 2 or 3 components, actual number is ")
<< localIdTy.getNumElements();
return success();
}
//===----------------------------------------------------------------------===//
// spv.GroupNonUniformBallotOp
//===----------------------------------------------------------------------===//

View File

@ -7,4 +7,16 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
spv.ReturnValue %0: vector<4xi32>
}
// CHECK-LABEL: @group_broadcast_1
spv.func @group_broadcast_1(%value: f32, %localid: i32 ) -> f32 "None" {
// CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, i32
%0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, i32
spv.ReturnValue %0: f32
}
// CHECK-LABEL: @group_broadcast_2
spv.func @group_broadcast_2(%value: f32, %localid: vector<3xi32> ) -> f32 "None" {
// CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, vector<3xi32>
%0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
spv.ReturnValue %0: f32
}
}

View File

@ -9,3 +9,55 @@ func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
%0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
return %0: vector<4xi32>
}
// -----
//===----------------------------------------------------------------------===//
// spv.GroupBroadcast
//===----------------------------------------------------------------------===//
func @group_broadcast_scalar(%value: f32, %localid: i32 ) -> f32 {
// CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, i32
%0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, i32
return %0: f32
}
// -----
func @group_broadcast_scalar_vector(%value: f32, %localid: vector<3xi32> ) -> f32 {
// CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, vector<3xi32>
%0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
return %0: f32
}
// -----
func @group_broadcast_vector(%value: vector<4xf32>, %localid: vector<3xi32> ) -> vector<4xf32> {
// CHECK: spv.GroupBroadcast "Subgroup" %{{.*}}, %{{.*}} : vector<4xf32>, vector<3xi32>
%0 = spv.GroupBroadcast "Subgroup" %value, %localid : vector<4xf32>, vector<3xi32>
return %0: vector<4xf32>
}
// -----
func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> ) -> f32 {
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
%0 = spv.GroupBroadcast "Device" %value, %localid : f32, vector<3xi32>
return %0: f32
}
// -----
func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3xf32> ) -> f32 {
// expected-error @+1 {{operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
%0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<3xf32>
return %0: f32
}
// -----
func @group_broadcast_negative_locid_vec4(%value: f32, %localid: vector<4xi32> ) -> f32 {
// expected-error @+1 {{localid is a vector and can be with only 2 or 3 components, actual number is 4}}
%0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<4xi32>
return %0: f32
}