forked from OSchip/llvm-project
[spirv] Allow block arguments on spv.Branch(Conditional)
We will use block arguments as the way to model SPIR-V OpPhi in the SPIR-V dialect. This CL also adds a few useful helper methods to both ops to get the block arguments. Also added tests for branch weight (de)serialization. PiperOrigin-RevId: 275960797
This commit is contained in:
parent
5f867d26b4
commit
d9fe892e42
|
@ -46,23 +46,29 @@ def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
|
|||
|
||||
``` {.ebnf}
|
||||
branch-op ::= `spv.Branch` successor
|
||||
successor ::= bb-id branch-use-list?
|
||||
branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)`
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
spv.Branch ^target
|
||||
spv.Branch ^target(%0, %1: i32, f32)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
let arguments = (ins
|
||||
Variadic<AnyType>:$block_arguments
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<
|
||||
"Builder *, OperationState &state, Block *successor", [{
|
||||
state.addSuccessor(successor, {});
|
||||
"Builder *, OperationState &state, "
|
||||
"Block *successor, ArrayRef<Value *> arguments = {}", [{
|
||||
state.addSuccessor(successor, arguments);
|
||||
}]
|
||||
>
|
||||
];
|
||||
|
@ -70,7 +76,13 @@ def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
|
|||
let skipDefaultBuilders = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Returns the branch target block.
|
||||
Block *getTarget() { return getOperation()->getSuccessor(0); }
|
||||
|
||||
/// Returns the block arguments.
|
||||
operand_range getBlockArguments() {
|
||||
return getOperation()->getSuccessorOperands(0);
|
||||
}
|
||||
}];
|
||||
|
||||
let autogenSerialization = 0;
|
||||
|
@ -105,17 +117,21 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
|
|||
branch-conditional-op ::= `spv.BranchConditional` ssa-use
|
||||
(`[` integer-literal, integer-literal `]`)?
|
||||
`,` successor `,` successor
|
||||
successor ::= bb-id branch-use-list?
|
||||
branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)`
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
spv.BranchConditional %condition, ^true_branch, ^false_branch
|
||||
spv.BranchConditional %condition, ^true_branch(%0: i32), ^false_branch(%1: i32)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_Bool:$condition,
|
||||
Variadic<AnyType>:$branch_arguments,
|
||||
OptionalAttr<I32ArrayAttr>:$branch_weights
|
||||
);
|
||||
|
||||
|
@ -124,12 +140,13 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
|
|||
let builders = [
|
||||
OpBuilder<
|
||||
"Builder *builder, OperationState &state, Value *condition, "
|
||||
"Block *trueBranch, Block *falseBranch, "
|
||||
"Optional<std::pair<uint32_t, uint32_t>> weights",
|
||||
"Block *trueBlock, ArrayRef<Value *> trueArguments, "
|
||||
"Block *falseBlock, ArrayRef<Value *> falseArguments, "
|
||||
"Optional<std::pair<uint32_t, uint32_t>> weights = {}",
|
||||
[{
|
||||
state.addOperands(condition);
|
||||
state.addSuccessor(trueBranch, {});
|
||||
state.addSuccessor(falseBranch, {});
|
||||
state.addSuccessor(trueBlock, trueArguments);
|
||||
state.addSuccessor(falseBlock, falseArguments);
|
||||
if (weights) {
|
||||
auto attr =
|
||||
builder->getI32ArrayAttr({static_cast<int32_t>(weights->first),
|
||||
|
@ -145,12 +162,57 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
|
|||
let autogenSerialization = 0;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Branch indices into the successor list.
|
||||
/// Branch indices into the successor list.
|
||||
enum { kTrueIndex = 0, kFalseIndex = 1 };
|
||||
|
||||
/// Returns the target block for the true branch.
|
||||
Block *getTrueBlock() { return getOperation()->getSuccessor(kTrueIndex); }
|
||||
|
||||
/// Returns the target block for the false branch.
|
||||
Block *getFalseBlock() { return getOperation()->getSuccessor(kFalseIndex); }
|
||||
|
||||
/// Returns the number of arguments to the true target block.
|
||||
unsigned getNumTrueBlockArguments() {
|
||||
return getNumSuccessorOperands(kTrueIndex);
|
||||
}
|
||||
|
||||
/// Returns the number of arguments to the false target block.
|
||||
unsigned getNumFalseBlockArguments() {
|
||||
return getNumSuccessorOperands(kFalseIndex);
|
||||
}
|
||||
|
||||
// Iterator and range support for true target block arguments.
|
||||
operand_iterator true_block_argument_begin() {
|
||||
return operand_begin() + getTrueBlockArgumentIndex();
|
||||
}
|
||||
operand_iterator true_block_argument_end() {
|
||||
return true_block_argument_begin() + getNumTrueBlockArguments();
|
||||
}
|
||||
operand_range getTrueBlockArguments() {
|
||||
return {true_block_argument_begin(), true_block_argument_end()};
|
||||
}
|
||||
|
||||
// Iterator and range support for false target block arguments.
|
||||
operand_iterator false_block_argument_begin() {
|
||||
return true_block_argument_end();
|
||||
}
|
||||
operand_iterator false_block_argument_end() {
|
||||
return false_block_argument_begin() + getNumFalseBlockArguments();
|
||||
}
|
||||
operand_range getFalseBlockArguments() {
|
||||
return {false_block_argument_begin(), false_block_argument_end()};
|
||||
}
|
||||
|
||||
private:
|
||||
/// Gets the index of the first true block argument in the operand list.
|
||||
unsigned getTrueBlockArgumentIndex() {
|
||||
return 1; // Omit the first argument, which is the condition.
|
||||
}
|
||||
|
||||
/// Gets the index of the first false block argument in the operand list.
|
||||
unsigned getFalseBlockArgumentIndex() {
|
||||
return getTrueBlockArgumentIndex() + getNumTrueBlockArguments();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -1489,8 +1489,10 @@ Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
|
|||
weights = std::make_pair(operands[3], operands[4]);
|
||||
}
|
||||
|
||||
opBuilder.create<spirv::BranchConditionalOp>(unknownLoc, condition, trueBlock,
|
||||
falseBlock, weights);
|
||||
opBuilder.create<spirv::BranchConditionalOp>(
|
||||
unknownLoc, condition, trueBlock,
|
||||
/*trueArguments=*/ArrayRef<Value *>(), falseBlock,
|
||||
/*falseArguments=*/ArrayRef<Value *>(), weights);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -22,8 +22,8 @@ spv.module "Logical" "GLSL450" {
|
|||
%val0 = spv.Load "Function" %var : i32
|
||||
// CHECK-NEXT: spv.SLessThan
|
||||
%cmp = spv.SLessThan %val0, %count : i32
|
||||
// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb2, ^bb4
|
||||
spv.BranchConditional %cmp, ^body, ^merge
|
||||
// CHECK-NEXT: spv.BranchConditional %{{.*}} [1, 1], ^bb2, ^bb4
|
||||
spv.BranchConditional %cmp [1, 1], ^body, ^merge
|
||||
|
||||
// CHECK-NEXT: ^bb2:
|
||||
^body:
|
||||
|
|
|
@ -15,8 +15,8 @@ spv.module "Logical" "GLSL450" {
|
|||
// CHECK-NEXT: spv.constant 0
|
||||
// CHECK-NEXT: spv.Variable
|
||||
spv.selection {
|
||||
// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
|
||||
spv.BranchConditional %cond, ^then, ^else
|
||||
// CHECK-NEXT: spv.BranchConditional %{{.*}} [5, 10], ^bb1, ^bb2
|
||||
spv.BranchConditional %cond [5, 10], ^then, ^else
|
||||
|
||||
// CHECK-NEXT: ^bb1:
|
||||
^then:
|
||||
|
|
|
@ -13,6 +13,16 @@ func @branch() -> () {
|
|||
|
||||
// -----
|
||||
|
||||
func @branch_argument() -> () {
|
||||
%zero = spv.constant 0 : i32
|
||||
// CHECK: spv.Branch ^bb1(%{{.*}}, %{{.*}} : i32, i32)
|
||||
spv.Branch ^next(%zero, %zero: i32, i32)
|
||||
^next(%arg0: i32, %arg1: i32):
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @missing_accessor() -> () {
|
||||
spv.Branch
|
||||
// expected-error @+1 {{expected block name}}
|
||||
|
@ -32,16 +42,6 @@ func @wrong_accessor_count() -> () {
|
|||
|
||||
// -----
|
||||
|
||||
func @accessor_argument_disallowed() -> () {
|
||||
%zero = spv.constant 0 : i32
|
||||
// expected-error @+1 {{requires zero operands}}
|
||||
"spv.Branch"()[^next(%zero : i32)] : () -> ()
|
||||
^next(%arg: i32):
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.BranchConditional
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -60,6 +60,24 @@ func @cond_branch() -> () {
|
|||
|
||||
// -----
|
||||
|
||||
func @cond_branch_argument() -> () {
|
||||
%true = spv.constant true
|
||||
%zero = spv.constant 0 : i32
|
||||
// CHECK: spv.BranchConditional %{{.*}}, ^bb1(%{{.*}}, %{{.*}} : i32, i32), ^bb2
|
||||
spv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1
|
||||
^true1(%arg0: i32, %arg1: i32):
|
||||
// CHECK: spv.BranchConditional %{{.*}}, ^bb3, ^bb4(%{{.*}}, %{{.*}} : i32, i32)
|
||||
spv.BranchConditional %true, ^true2, ^false2(%zero, %zero: i32, i32)
|
||||
^false1:
|
||||
spv.Return
|
||||
^true2:
|
||||
spv.Return
|
||||
^false2(%arg3: i32, %arg4: i32):
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cond_branch_with_weights() -> () {
|
||||
%true = spv.constant true
|
||||
// CHECK: spv.BranchConditional %{{.*}} [5, 10]
|
||||
|
@ -108,18 +126,6 @@ func @wrong_accessor_count() -> () {
|
|||
|
||||
// -----
|
||||
|
||||
func @accessor_argument_disallowed() -> () {
|
||||
%true = spv.constant true
|
||||
// expected-error @+1 {{requires a single operand}}
|
||||
"spv.BranchConditional"(%true)[^one(%true : i1), ^two] : (i1) -> ()
|
||||
^one(%arg : i1):
|
||||
spv.Return
|
||||
^two:
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @wrong_number_of_weights() -> () {
|
||||
%true = spv.constant true
|
||||
// expected-error @+1 {{must have exactly two branch weights}}
|
||||
|
|
Loading…
Reference in New Issue