forked from OSchip/llvm-project
[spirv] Allow block arguments on spv.Branch(Conditional)
We will use block arguments as the way to model SPIR-V OpPhi in the SPIR-V dialect. This CL also adds a few useful helper methods to both ops to get the block arguments. Also added tests for branch weight (de)serialization. PiperOrigin-RevId: 275960797
This commit is contained in:
parent
5f867d26b4
commit
d9fe892e42
|
@ -46,23 +46,29 @@ def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
|
||||||
|
|
||||||
``` {.ebnf}
|
``` {.ebnf}
|
||||||
branch-op ::= `spv.Branch` successor
|
branch-op ::= `spv.Branch` successor
|
||||||
|
successor ::= bb-id branch-use-list?
|
||||||
|
branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)`
|
||||||
```
|
```
|
||||||
|
|
||||||
For example:
|
For example:
|
||||||
|
|
||||||
```
|
```
|
||||||
spv.Branch ^target
|
spv.Branch ^target
|
||||||
|
spv.Branch ^target(%0, %1: i32, f32)
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins);
|
let arguments = (ins
|
||||||
|
Variadic<AnyType>:$block_arguments
|
||||||
|
);
|
||||||
|
|
||||||
let results = (outs);
|
let results = (outs);
|
||||||
|
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<
|
OpBuilder<
|
||||||
"Builder *, OperationState &state, Block *successor", [{
|
"Builder *, OperationState &state, "
|
||||||
state.addSuccessor(successor, {});
|
"Block *successor, ArrayRef<Value *> arguments = {}", [{
|
||||||
|
state.addSuccessor(successor, arguments);
|
||||||
}]
|
}]
|
||||||
>
|
>
|
||||||
];
|
];
|
||||||
|
@ -70,7 +76,13 @@ def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
|
||||||
let skipDefaultBuilders = 1;
|
let skipDefaultBuilders = 1;
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
|
/// Returns the branch target block.
|
||||||
Block *getTarget() { return getOperation()->getSuccessor(0); }
|
Block *getTarget() { return getOperation()->getSuccessor(0); }
|
||||||
|
|
||||||
|
/// Returns the block arguments.
|
||||||
|
operand_range getBlockArguments() {
|
||||||
|
return getOperation()->getSuccessorOperands(0);
|
||||||
|
}
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let autogenSerialization = 0;
|
let autogenSerialization = 0;
|
||||||
|
@ -105,17 +117,21 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
|
||||||
branch-conditional-op ::= `spv.BranchConditional` ssa-use
|
branch-conditional-op ::= `spv.BranchConditional` ssa-use
|
||||||
(`[` integer-literal, integer-literal `]`)?
|
(`[` integer-literal, integer-literal `]`)?
|
||||||
`,` successor `,` successor
|
`,` successor `,` successor
|
||||||
|
successor ::= bb-id branch-use-list?
|
||||||
|
branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)`
|
||||||
```
|
```
|
||||||
|
|
||||||
For example:
|
For example:
|
||||||
|
|
||||||
```
|
```
|
||||||
spv.BranchConditional %condition, ^true_branch, ^false_branch
|
spv.BranchConditional %condition, ^true_branch, ^false_branch
|
||||||
|
spv.BranchConditional %condition, ^true_branch(%0: i32), ^false_branch(%1: i32)
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SPV_Bool:$condition,
|
SPV_Bool:$condition,
|
||||||
|
Variadic<AnyType>:$branch_arguments,
|
||||||
OptionalAttr<I32ArrayAttr>:$branch_weights
|
OptionalAttr<I32ArrayAttr>:$branch_weights
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -124,12 +140,13 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<
|
OpBuilder<
|
||||||
"Builder *builder, OperationState &state, Value *condition, "
|
"Builder *builder, OperationState &state, Value *condition, "
|
||||||
"Block *trueBranch, Block *falseBranch, "
|
"Block *trueBlock, ArrayRef<Value *> trueArguments, "
|
||||||
"Optional<std::pair<uint32_t, uint32_t>> weights",
|
"Block *falseBlock, ArrayRef<Value *> falseArguments, "
|
||||||
|
"Optional<std::pair<uint32_t, uint32_t>> weights = {}",
|
||||||
[{
|
[{
|
||||||
state.addOperands(condition);
|
state.addOperands(condition);
|
||||||
state.addSuccessor(trueBranch, {});
|
state.addSuccessor(trueBlock, trueArguments);
|
||||||
state.addSuccessor(falseBranch, {});
|
state.addSuccessor(falseBlock, falseArguments);
|
||||||
if (weights) {
|
if (weights) {
|
||||||
auto attr =
|
auto attr =
|
||||||
builder->getI32ArrayAttr({static_cast<int32_t>(weights->first),
|
builder->getI32ArrayAttr({static_cast<int32_t>(weights->first),
|
||||||
|
@ -145,12 +162,57 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
|
||||||
let autogenSerialization = 0;
|
let autogenSerialization = 0;
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
// Branch indices into the successor list.
|
/// Branch indices into the successor list.
|
||||||
enum { kTrueIndex = 0, kFalseIndex = 1 };
|
enum { kTrueIndex = 0, kFalseIndex = 1 };
|
||||||
|
|
||||||
|
/// Returns the target block for the true branch.
|
||||||
Block *getTrueBlock() { return getOperation()->getSuccessor(kTrueIndex); }
|
Block *getTrueBlock() { return getOperation()->getSuccessor(kTrueIndex); }
|
||||||
|
|
||||||
|
/// Returns the target block for the false branch.
|
||||||
Block *getFalseBlock() { return getOperation()->getSuccessor(kFalseIndex); }
|
Block *getFalseBlock() { return getOperation()->getSuccessor(kFalseIndex); }
|
||||||
|
|
||||||
|
/// Returns the number of arguments to the true target block.
|
||||||
|
unsigned getNumTrueBlockArguments() {
|
||||||
|
return getNumSuccessorOperands(kTrueIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the number of arguments to the false target block.
|
||||||
|
unsigned getNumFalseBlockArguments() {
|
||||||
|
return getNumSuccessorOperands(kFalseIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterator and range support for true target block arguments.
|
||||||
|
operand_iterator true_block_argument_begin() {
|
||||||
|
return operand_begin() + getTrueBlockArgumentIndex();
|
||||||
|
}
|
||||||
|
operand_iterator true_block_argument_end() {
|
||||||
|
return true_block_argument_begin() + getNumTrueBlockArguments();
|
||||||
|
}
|
||||||
|
operand_range getTrueBlockArguments() {
|
||||||
|
return {true_block_argument_begin(), true_block_argument_end()};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterator and range support for false target block arguments.
|
||||||
|
operand_iterator false_block_argument_begin() {
|
||||||
|
return true_block_argument_end();
|
||||||
|
}
|
||||||
|
operand_iterator false_block_argument_end() {
|
||||||
|
return false_block_argument_begin() + getNumFalseBlockArguments();
|
||||||
|
}
|
||||||
|
operand_range getFalseBlockArguments() {
|
||||||
|
return {false_block_argument_begin(), false_block_argument_end()};
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Gets the index of the first true block argument in the operand list.
|
||||||
|
unsigned getTrueBlockArgumentIndex() {
|
||||||
|
return 1; // Omit the first argument, which is the condition.
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gets the index of the first false block argument in the operand list.
|
||||||
|
unsigned getFalseBlockArgumentIndex() {
|
||||||
|
return getTrueBlockArgumentIndex() + getNumTrueBlockArguments();
|
||||||
|
}
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1489,8 +1489,10 @@ Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
|
||||||
weights = std::make_pair(operands[3], operands[4]);
|
weights = std::make_pair(operands[3], operands[4]);
|
||||||
}
|
}
|
||||||
|
|
||||||
opBuilder.create<spirv::BranchConditionalOp>(unknownLoc, condition, trueBlock,
|
opBuilder.create<spirv::BranchConditionalOp>(
|
||||||
falseBlock, weights);
|
unknownLoc, condition, trueBlock,
|
||||||
|
/*trueArguments=*/ArrayRef<Value *>(), falseBlock,
|
||||||
|
/*falseArguments=*/ArrayRef<Value *>(), weights);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,8 +22,8 @@ spv.module "Logical" "GLSL450" {
|
||||||
%val0 = spv.Load "Function" %var : i32
|
%val0 = spv.Load "Function" %var : i32
|
||||||
// CHECK-NEXT: spv.SLessThan
|
// CHECK-NEXT: spv.SLessThan
|
||||||
%cmp = spv.SLessThan %val0, %count : i32
|
%cmp = spv.SLessThan %val0, %count : i32
|
||||||
// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb2, ^bb4
|
// CHECK-NEXT: spv.BranchConditional %{{.*}} [1, 1], ^bb2, ^bb4
|
||||||
spv.BranchConditional %cmp, ^body, ^merge
|
spv.BranchConditional %cmp [1, 1], ^body, ^merge
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb2:
|
// CHECK-NEXT: ^bb2:
|
||||||
^body:
|
^body:
|
||||||
|
|
|
@ -15,8 +15,8 @@ spv.module "Logical" "GLSL450" {
|
||||||
// CHECK-NEXT: spv.constant 0
|
// CHECK-NEXT: spv.constant 0
|
||||||
// CHECK-NEXT: spv.Variable
|
// CHECK-NEXT: spv.Variable
|
||||||
spv.selection {
|
spv.selection {
|
||||||
// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
|
// CHECK-NEXT: spv.BranchConditional %{{.*}} [5, 10], ^bb1, ^bb2
|
||||||
spv.BranchConditional %cond, ^then, ^else
|
spv.BranchConditional %cond [5, 10], ^then, ^else
|
||||||
|
|
||||||
// CHECK-NEXT: ^bb1:
|
// CHECK-NEXT: ^bb1:
|
||||||
^then:
|
^then:
|
||||||
|
|
|
@ -13,6 +13,16 @@ func @branch() -> () {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @branch_argument() -> () {
|
||||||
|
%zero = spv.constant 0 : i32
|
||||||
|
// CHECK: spv.Branch ^bb1(%{{.*}}, %{{.*}} : i32, i32)
|
||||||
|
spv.Branch ^next(%zero, %zero: i32, i32)
|
||||||
|
^next(%arg0: i32, %arg1: i32):
|
||||||
|
spv.Return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @missing_accessor() -> () {
|
func @missing_accessor() -> () {
|
||||||
spv.Branch
|
spv.Branch
|
||||||
// expected-error @+1 {{expected block name}}
|
// expected-error @+1 {{expected block name}}
|
||||||
|
@ -32,16 +42,6 @@ func @wrong_accessor_count() -> () {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @accessor_argument_disallowed() -> () {
|
|
||||||
%zero = spv.constant 0 : i32
|
|
||||||
// expected-error @+1 {{requires zero operands}}
|
|
||||||
"spv.Branch"()[^next(%zero : i32)] : () -> ()
|
|
||||||
^next(%arg: i32):
|
|
||||||
spv.Return
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// spv.BranchConditional
|
// spv.BranchConditional
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -60,6 +60,24 @@ func @cond_branch() -> () {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @cond_branch_argument() -> () {
|
||||||
|
%true = spv.constant true
|
||||||
|
%zero = spv.constant 0 : i32
|
||||||
|
// CHECK: spv.BranchConditional %{{.*}}, ^bb1(%{{.*}}, %{{.*}} : i32, i32), ^bb2
|
||||||
|
spv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1
|
||||||
|
^true1(%arg0: i32, %arg1: i32):
|
||||||
|
// CHECK: spv.BranchConditional %{{.*}}, ^bb3, ^bb4(%{{.*}}, %{{.*}} : i32, i32)
|
||||||
|
spv.BranchConditional %true, ^true2, ^false2(%zero, %zero: i32, i32)
|
||||||
|
^false1:
|
||||||
|
spv.Return
|
||||||
|
^true2:
|
||||||
|
spv.Return
|
||||||
|
^false2(%arg3: i32, %arg4: i32):
|
||||||
|
spv.Return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @cond_branch_with_weights() -> () {
|
func @cond_branch_with_weights() -> () {
|
||||||
%true = spv.constant true
|
%true = spv.constant true
|
||||||
// CHECK: spv.BranchConditional %{{.*}} [5, 10]
|
// CHECK: spv.BranchConditional %{{.*}} [5, 10]
|
||||||
|
@ -108,18 +126,6 @@ func @wrong_accessor_count() -> () {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @accessor_argument_disallowed() -> () {
|
|
||||||
%true = spv.constant true
|
|
||||||
// expected-error @+1 {{requires a single operand}}
|
|
||||||
"spv.BranchConditional"(%true)[^one(%true : i1), ^two] : (i1) -> ()
|
|
||||||
^one(%arg : i1):
|
|
||||||
spv.Return
|
|
||||||
^two:
|
|
||||||
spv.Return
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @wrong_number_of_weights() -> () {
|
func @wrong_number_of_weights() -> () {
|
||||||
%true = spv.constant true
|
%true = spv.constant true
|
||||||
// expected-error @+1 {{must have exactly two branch weights}}
|
// expected-error @+1 {{must have exactly two branch weights}}
|
||||||
|
|
Loading…
Reference in New Issue