Add spv.Branch and spv.BranchConditional

This CL just covers the op definition, its parsing, printing,
and verification. (De)serialization is to be implemented
in a subsequent CL.

PiperOrigin-RevId: 266431077
This commit is contained in:
Lei Zhang 2019-08-30 12:17:21 -07:00 committed by A. Unique TensorFlower
parent 3ee3710fd1
commit 4f6c29223e
5 changed files with 371 additions and 3 deletions

View File

@ -132,6 +132,8 @@ def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>;
def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>;
def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>;
def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>;
def SPV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>;
def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>;
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
@ -154,7 +156,8 @@ def SPV_OpcodeAttr :
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
SPV_OC_OpLabel, SPV_OC_OpReturn, SPV_OC_OpReturnValue
SPV_OC_OpLabel, SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
SPV_OC_OpReturnValue
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";

View File

@ -31,6 +31,112 @@ include "mlir/SPIRV/SPIRVBase.td"
// -----
def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
let summary = "Unconditional branch to target block.";
let description = [{
This instruction must be the last instruction in a block.
### Custom assembly form
``` {.ebnf}
branch-op ::= `spv.Branch` successor
```
For example:
```
spv.Branch ^target
```
}];
let arguments = (ins);
let results = (outs);
let builders = [
OpBuilder<
"Builder *, OperationState *state, Block *successor", [{
state->addSuccessor(successor, {});
}]
>
];
let skipDefaultBuilders = 1;
let autogenSerialization = 0;
}
// -----
def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
let summary = [{
If Condition is true, branch to true block, otherwise branch to false
block.
}];
let description = [{
Condition must be a Boolean type scalar.
Branch weights are unsigned 32-bit integer literals. There must be
either no Branch Weights or exactly two branch weights. If present, the
first is the weight for branching to True Label, and the second is the
weight for branching to False Label. The implied probability that a
branch is taken is its weight divided by the sum of the two Branch
weights. At least one weight must be non-zero. A weight of zero does not
imply a branch is dead or permit its removal; branch weights are only
hints. The two weights must not overflow a 32-bit unsigned integer when
added together.
This instruction must be the last instruction in a block.
### Custom assembly form
``` {.ebnf}
branch-conditional-op ::= `spv.BranchConditional` ssa-use
(`[` integer-literal, integer-literal `]`)?
`,` successor `,` successor
```
For example:
```
spv.BranchConditional %condition, ^true_branch, ^false_branch
```
}];
let arguments = (ins
SPV_Bool:$condition,
OptionalAttr<I32ArrayAttr>:$branch_weights
);
let results = (outs);
let builders = [
OpBuilder<
"Builder *, OperationState *state, Value *condition, "
"Block *trueBranch, Block *falseBranch, /*optional*/ArrayAttr weights",
[{
state->addOperands(condition);
state->addSuccessor(trueBranch, {});
state->addSuccessor(falseBranch, {});
state->addAttribute("branch_weights", weights);
}]
>
];
let skipDefaultBuilders = 1;
let autogenSerialization = 0;
let extraClassDeclaration = [{
// Branch indices into the successor list.
enum { kTrueIndex = 0, kFalseIndex = 1 };
}];
}
// -----
def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> {
let summary = "Return with no value from a function with void return type.";

View File

@ -33,6 +33,7 @@ using namespace mlir;
// TODO(antiagainst): generate these strings using ODS.
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
static constexpr const char kDefaultValueAttrName[] = "default_value";
static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kIndicesAttrName[] = "indices";
@ -486,6 +487,119 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
return success();
}
//===----------------------------------------------------------------------===//
// spv.BranchOp
//===----------------------------------------------------------------------===//
static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *state) {
Block *dest;
SmallVector<Value *, 4> destOperands;
if (parser->parseSuccessorAndUseList(dest, destOperands))
return failure();
state->addSuccessor(dest, destOperands);
return success();
}
static void print(spirv::BranchOp branchOp, OpAsmPrinter *printer) {
*printer << spirv::BranchOp::getOperationName() << ' ';
printer->printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0);
}
static LogicalResult verify(spirv::BranchOp branchOp) {
auto *op = branchOp.getOperation();
if (op->getNumSuccessors() != 1)
branchOp.emitOpError("must have exactly one successor");
return success();
}
//===----------------------------------------------------------------------===//
// spv.BranchConditionalOp
//===----------------------------------------------------------------------===//
static ParseResult parseBranchConditionalOp(OpAsmParser *parser,
OperationState *state) {
auto &builder = parser->getBuilder();
OpAsmParser::OperandType condInfo;
Block *dest;
SmallVector<Value *, 4> destOperands;
// Parse the condition.
Type boolTy = builder.getI1Type();
if (parser->parseOperand(condInfo) ||
parser->resolveOperand(condInfo, boolTy, state->operands))
return failure();
// Parse the optional branch weights.
if (succeeded(parser->parseOptionalLSquare())) {
IntegerAttr trueWeight, falseWeight;
SmallVector<NamedAttribute, 2> weights;
auto i32Type = builder.getIntegerType(32);
if (parser->parseAttribute(trueWeight, i32Type, "weight", weights) ||
parser->parseComma() ||
parser->parseAttribute(falseWeight, i32Type, "weight", weights) ||
parser->parseRSquare())
return failure();
state->addAttribute(kBranchWeightAttrName,
builder.getArrayAttr({trueWeight, falseWeight}));
}
// Parse the true branch.
if (parser->parseComma() ||
parser->parseSuccessorAndUseList(dest, destOperands))
return failure();
state->addSuccessor(dest, destOperands);
// Parse the false branch.
destOperands.clear();
if (parser->parseComma() ||
parser->parseSuccessorAndUseList(dest, destOperands))
return failure();
state->addSuccessor(dest, destOperands);
return success();
}
static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter *printer) {
*printer << spirv::BranchConditionalOp::getOperationName() << ' ';
printer->printOperand(branchOp.condition());
if (auto weights = branchOp.branch_weights()) {
*printer << " [";
mlir::interleaveComma(
weights->getValue(), printer->getStream(),
[&](Attribute a) { *printer << a.cast<IntegerAttr>().getInt(); });
*printer << "]";
}
*printer << ", ";
printer->printSuccessorAndUseList(branchOp.getOperation(),
spirv::BranchConditionalOp::kTrueIndex);
*printer << ", ";
printer->printSuccessorAndUseList(branchOp.getOperation(),
spirv::BranchConditionalOp::kFalseIndex);
}
static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
auto *op = branchOp.getOperation();
if (op->getNumSuccessors() != 2)
return branchOp.emitOpError("must have exactly two successors");
if (auto weights = branchOp.branch_weights()) {
if (weights->getValue().size() != 2) {
return branchOp.emitOpError("must have exactly two branch weights");
}
if (llvm::all_of(*weights, [](Attribute attr) {
return attr.cast<IntegerAttr>().getValue().isNullValue();
}))
return branchOp.emitOpError("branch weights cannot both be zero");
}
return success();
}
//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//
@ -1093,6 +1207,7 @@ static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
}
return success();
}
//===----------------------------------------------------------------------===//
// spv.Return
//===----------------------------------------------------------------------===//

View File

@ -1,5 +1,149 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
// spv.Branch
//===----------------------------------------------------------------------===//
func @branch() -> () {
// CHECK: spv.Branch ^bb1
spv.Branch ^next
^next:
spv.Return
}
// -----
func @missing_accessor() -> () {
spv.Branch
// expected-error @+1 {{expected block name}}
}
// -----
func @wrong_accessor_count() -> () {
%true = spv.constant true
// expected-error @+1 {{must have exactly one successor}}
"spv.Branch"()[^one, ^two] : () -> ()
^one:
spv.Return
^two:
spv.Return
}
// -----
func @accessor_argument_disallowed() -> () {
%zero = spv.constant 0 : i32
// expected-error @+1 {{requires zero operands}}
"spv.Branch"()[^next(%zero : i32)] : () -> ()
^next(%arg: i32):
spv.Return
}
// -----
//===----------------------------------------------------------------------===//
// spv.BranchConditional
//===----------------------------------------------------------------------===//
func @cond_branch() -> () {
%true = spv.constant true
// CHECK: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
spv.BranchConditional %true, ^one, ^two
// CHECK: ^bb1
^one:
spv.Return
// CHECK: ^bb2
^two:
spv.Return
}
// -----
func @cond_branch_with_weights() -> () {
%true = spv.constant true
// CHECK: spv.BranchConditional %{{.*}} [5, 10]
spv.BranchConditional %true [5, 10], ^one, ^two
^one:
spv.Return
^two:
spv.Return
}
// -----
func @missing_condition() -> () {
// expected-error @+1 {{expected SSA operand}}
spv.BranchConditional ^one, ^two
^one:
spv.Return
^two:
spv.Return
}
// -----
func @wrong_condition_type() -> () {
// expected-note @+1 {{prior use here}}
%zero = spv.constant 0 : i32
// expected-error @+1 {{use of value '%zero' expects different type than prior uses: 'i1' vs 'i32'}}
spv.BranchConditional %zero, ^one, ^two
^one:
spv.Return
^two:
spv.Return
}
// -----
func @wrong_accessor_count() -> () {
%true = spv.constant true
// expected-error @+1 {{must have exactly two successors}}
"spv.BranchConditional"(%true)[^one] : (i1) -> ()
^one:
spv.Return
^two:
spv.Return
}
// -----
func @accessor_argment_disallowed() -> () {
%true = spv.constant true
// expected-error @+1 {{requires a single operand}}
"spv.BranchConditional"(%true)[^one(%true : i1), ^two] : (i1) -> ()
^one(%arg : i1):
spv.Return
^two:
spv.Return
}
// -----
func @wrong_number_of_weights() -> () {
%true = spv.constant true
// expected-error @+1 {{must have exactly two branch weights}}
"spv.BranchConditional"(%true)[^one, ^two] {branch_weights = [1 : i32, 2 : i32, 3 : i32]} : (i1) -> ()
^one:
spv.Return
^two:
spv.Return
}
// -----
func @weights_cannot_both_be_zero() -> () {
%true = spv.constant true
// expected-error @+1 {{branch weights cannot both be zero}}
spv.BranchConditional %true [0, 0], ^one, ^two
^one:
spv.Return
^two:
spv.Return
}
// -----
//===----------------------------------------------------------------------===//
// spv.Return
//===----------------------------------------------------------------------===//

View File

@ -421,14 +421,14 @@ def get_op_definition(instruction, doc, existing_info):
arguments = existing_info.get('arguments', None)
if arguments is None:
arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
arguments = '\n '.join(arguments)
arguments = ',\n '.join(arguments)
if arguments:
# Prepend and append whitespace for formatting
arguments = '\n {}\n '.format(arguments)
assembly = existing_info.get('assembly', None)
if assembly is None:
assembly = ' ``` {.ebnf}\n'\
assembly = '\n ``` {.ebnf}\n'\
' [TODO]\n'\
' ```\n\n'\
' For example:\n\n'\