[spirv] Allow block arguments on spv.Branch(Conditional)

We will use block arguments as the way to model SPIR-V OpPhi in
the SPIR-V dialect.

This CL also adds a few useful helper methods to both ops to
get the block arguments.

Also added tests for branch weight (de)serialization.

PiperOrigin-RevId: 275960797
This commit is contained in:
Lei Zhang 2019-10-21 17:31:32 -07:00 committed by A. Unique TensorFlower
parent 5f867d26b4
commit d9fe892e42
5 changed files with 106 additions and 36 deletions

View File

@ -46,23 +46,29 @@ def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
``` {.ebnf} ``` {.ebnf}
branch-op ::= `spv.Branch` successor branch-op ::= `spv.Branch` successor
successor ::= bb-id branch-use-list?
branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)`
``` ```
For example: For example:
``` ```
spv.Branch ^target spv.Branch ^target
spv.Branch ^target(%0, %1: i32, f32)
``` ```
}]; }];
let arguments = (ins); let arguments = (ins
Variadic<AnyType>:$block_arguments
);
let results = (outs); let results = (outs);
let builders = [ let builders = [
OpBuilder< OpBuilder<
"Builder *, OperationState &state, Block *successor", [{ "Builder *, OperationState &state, "
state.addSuccessor(successor, {}); "Block *successor, ArrayRef<Value *> arguments = {}", [{
state.addSuccessor(successor, arguments);
}] }]
> >
]; ];
@ -70,7 +76,13 @@ def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
let skipDefaultBuilders = 1; let skipDefaultBuilders = 1;
let extraClassDeclaration = [{ let extraClassDeclaration = [{
/// Returns the branch target block.
Block *getTarget() { return getOperation()->getSuccessor(0); } Block *getTarget() { return getOperation()->getSuccessor(0); }
/// Returns the block arguments.
operand_range getBlockArguments() {
return getOperation()->getSuccessorOperands(0);
}
}]; }];
let autogenSerialization = 0; let autogenSerialization = 0;
@ -105,17 +117,21 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
branch-conditional-op ::= `spv.BranchConditional` ssa-use branch-conditional-op ::= `spv.BranchConditional` ssa-use
(`[` integer-literal, integer-literal `]`)? (`[` integer-literal, integer-literal `]`)?
`,` successor `,` successor `,` successor `,` successor
successor ::= bb-id branch-use-list?
branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)`
``` ```
For example: For example:
``` ```
spv.BranchConditional %condition, ^true_branch, ^false_branch spv.BranchConditional %condition, ^true_branch, ^false_branch
spv.BranchConditional %condition, ^true_branch(%0: i32), ^false_branch(%1: i32)
``` ```
}]; }];
let arguments = (ins let arguments = (ins
SPV_Bool:$condition, SPV_Bool:$condition,
Variadic<AnyType>:$branch_arguments,
OptionalAttr<I32ArrayAttr>:$branch_weights OptionalAttr<I32ArrayAttr>:$branch_weights
); );
@ -124,12 +140,13 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
let builders = [ let builders = [
OpBuilder< OpBuilder<
"Builder *builder, OperationState &state, Value *condition, " "Builder *builder, OperationState &state, Value *condition, "
"Block *trueBranch, Block *falseBranch, " "Block *trueBlock, ArrayRef<Value *> trueArguments, "
"Optional<std::pair<uint32_t, uint32_t>> weights", "Block *falseBlock, ArrayRef<Value *> falseArguments, "
"Optional<std::pair<uint32_t, uint32_t>> weights = {}",
[{ [{
state.addOperands(condition); state.addOperands(condition);
state.addSuccessor(trueBranch, {}); state.addSuccessor(trueBlock, trueArguments);
state.addSuccessor(falseBranch, {}); state.addSuccessor(falseBlock, falseArguments);
if (weights) { if (weights) {
auto attr = auto attr =
builder->getI32ArrayAttr({static_cast<int32_t>(weights->first), builder->getI32ArrayAttr({static_cast<int32_t>(weights->first),
@ -145,12 +162,57 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
let autogenSerialization = 0; let autogenSerialization = 0;
let extraClassDeclaration = [{ let extraClassDeclaration = [{
// Branch indices into the successor list. /// Branch indices into the successor list.
enum { kTrueIndex = 0, kFalseIndex = 1 }; enum { kTrueIndex = 0, kFalseIndex = 1 };
/// Returns the target block for the true branch.
Block *getTrueBlock() { return getOperation()->getSuccessor(kTrueIndex); } Block *getTrueBlock() { return getOperation()->getSuccessor(kTrueIndex); }
/// Returns the target block for the false branch.
Block *getFalseBlock() { return getOperation()->getSuccessor(kFalseIndex); } Block *getFalseBlock() { return getOperation()->getSuccessor(kFalseIndex); }
/// Returns the number of arguments to the true target block.
unsigned getNumTrueBlockArguments() {
return getNumSuccessorOperands(kTrueIndex);
}
/// Returns the number of arguments to the false target block.
unsigned getNumFalseBlockArguments() {
return getNumSuccessorOperands(kFalseIndex);
}
// Iterator and range support for true target block arguments.
operand_iterator true_block_argument_begin() {
return operand_begin() + getTrueBlockArgumentIndex();
}
operand_iterator true_block_argument_end() {
return true_block_argument_begin() + getNumTrueBlockArguments();
}
operand_range getTrueBlockArguments() {
return {true_block_argument_begin(), true_block_argument_end()};
}
// Iterator and range support for false target block arguments.
operand_iterator false_block_argument_begin() {
return true_block_argument_end();
}
operand_iterator false_block_argument_end() {
return false_block_argument_begin() + getNumFalseBlockArguments();
}
operand_range getFalseBlockArguments() {
return {false_block_argument_begin(), false_block_argument_end()};
}
private:
/// Gets the index of the first true block argument in the operand list.
unsigned getTrueBlockArgumentIndex() {
return 1; // Omit the first argument, which is the condition.
}
/// Gets the index of the first false block argument in the operand list.
unsigned getFalseBlockArgumentIndex() {
return getTrueBlockArgumentIndex() + getNumTrueBlockArguments();
}
}]; }];
} }

View File

@ -1489,8 +1489,10 @@ Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
weights = std::make_pair(operands[3], operands[4]); weights = std::make_pair(operands[3], operands[4]);
} }
opBuilder.create<spirv::BranchConditionalOp>(unknownLoc, condition, trueBlock, opBuilder.create<spirv::BranchConditionalOp>(
falseBlock, weights); unknownLoc, condition, trueBlock,
/*trueArguments=*/ArrayRef<Value *>(), falseBlock,
/*falseArguments=*/ArrayRef<Value *>(), weights);
return success(); return success();
} }

View File

@ -22,8 +22,8 @@ spv.module "Logical" "GLSL450" {
%val0 = spv.Load "Function" %var : i32 %val0 = spv.Load "Function" %var : i32
// CHECK-NEXT: spv.SLessThan // CHECK-NEXT: spv.SLessThan
%cmp = spv.SLessThan %val0, %count : i32 %cmp = spv.SLessThan %val0, %count : i32
// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb2, ^bb4 // CHECK-NEXT: spv.BranchConditional %{{.*}} [1, 1], ^bb2, ^bb4
spv.BranchConditional %cmp, ^body, ^merge spv.BranchConditional %cmp [1, 1], ^body, ^merge
// CHECK-NEXT: ^bb2: // CHECK-NEXT: ^bb2:
^body: ^body:

View File

@ -15,8 +15,8 @@ spv.module "Logical" "GLSL450" {
// CHECK-NEXT: spv.constant 0 // CHECK-NEXT: spv.constant 0
// CHECK-NEXT: spv.Variable // CHECK-NEXT: spv.Variable
spv.selection { spv.selection {
// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb1, ^bb2 // CHECK-NEXT: spv.BranchConditional %{{.*}} [5, 10], ^bb1, ^bb2
spv.BranchConditional %cond, ^then, ^else spv.BranchConditional %cond [5, 10], ^then, ^else
// CHECK-NEXT: ^bb1: // CHECK-NEXT: ^bb1:
^then: ^then:

View File

@ -13,6 +13,16 @@ func @branch() -> () {
// ----- // -----
func @branch_argument() -> () {
%zero = spv.constant 0 : i32
// CHECK: spv.Branch ^bb1(%{{.*}}, %{{.*}} : i32, i32)
spv.Branch ^next(%zero, %zero: i32, i32)
^next(%arg0: i32, %arg1: i32):
spv.Return
}
// -----
func @missing_accessor() -> () { func @missing_accessor() -> () {
spv.Branch spv.Branch
// expected-error @+1 {{expected block name}} // expected-error @+1 {{expected block name}}
@ -32,16 +42,6 @@ func @wrong_accessor_count() -> () {
// ----- // -----
func @accessor_argument_disallowed() -> () {
%zero = spv.constant 0 : i32
// expected-error @+1 {{requires zero operands}}
"spv.Branch"()[^next(%zero : i32)] : () -> ()
^next(%arg: i32):
spv.Return
}
// -----
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// spv.BranchConditional // spv.BranchConditional
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -60,6 +60,24 @@ func @cond_branch() -> () {
// ----- // -----
func @cond_branch_argument() -> () {
%true = spv.constant true
%zero = spv.constant 0 : i32
// CHECK: spv.BranchConditional %{{.*}}, ^bb1(%{{.*}}, %{{.*}} : i32, i32), ^bb2
spv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1
^true1(%arg0: i32, %arg1: i32):
// CHECK: spv.BranchConditional %{{.*}}, ^bb3, ^bb4(%{{.*}}, %{{.*}} : i32, i32)
spv.BranchConditional %true, ^true2, ^false2(%zero, %zero: i32, i32)
^false1:
spv.Return
^true2:
spv.Return
^false2(%arg3: i32, %arg4: i32):
spv.Return
}
// -----
func @cond_branch_with_weights() -> () { func @cond_branch_with_weights() -> () {
%true = spv.constant true %true = spv.constant true
// CHECK: spv.BranchConditional %{{.*}} [5, 10] // CHECK: spv.BranchConditional %{{.*}} [5, 10]
@ -108,18 +126,6 @@ func @wrong_accessor_count() -> () {
// ----- // -----
func @accessor_argument_disallowed() -> () {
%true = spv.constant true
// expected-error @+1 {{requires a single operand}}
"spv.BranchConditional"(%true)[^one(%true : i1), ^two] : (i1) -> ()
^one(%arg : i1):
spv.Return
^two:
spv.Return
}
// -----
func @wrong_number_of_weights() -> () { func @wrong_number_of_weights() -> () {
%true = spv.constant true %true = spv.constant true
// expected-error @+1 {{must have exactly two branch weights}} // expected-error @+1 {{must have exactly two branch weights}}