forked from OSchip/llvm-project
Add spv.Branch and spv.BranchConditional
This CL just covers the op definition, its parsing, printing, and verification. (De)serialization is to be implemented in a subsequent CL. PiperOrigin-RevId: 266431077
This commit is contained in:
parent
3ee3710fd1
commit
4f6c29223e
|
@ -132,6 +132,8 @@ def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>;
|
|||
def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>;
|
||||
def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>;
|
||||
def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>;
|
||||
def SPV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>;
|
||||
def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>;
|
||||
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
|
||||
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
|
||||
|
||||
|
@ -154,7 +156,8 @@ def SPV_OpcodeAttr :
|
|||
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
|
||||
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
|
||||
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
|
||||
SPV_OC_OpLabel, SPV_OC_OpReturn, SPV_OC_OpReturnValue
|
||||
SPV_OC_OpLabel, SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
|
||||
SPV_OC_OpReturnValue
|
||||
]> {
|
||||
let returnType = "::mlir::spirv::Opcode";
|
||||
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
|
||||
|
|
|
@ -31,6 +31,112 @@ include "mlir/SPIRV/SPIRVBase.td"
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
|
||||
let summary = "Unconditional branch to target block.";
|
||||
|
||||
let description = [{
|
||||
This instruction must be the last instruction in a block.
|
||||
|
||||
### Custom assembly form
|
||||
|
||||
``` {.ebnf}
|
||||
branch-op ::= `spv.Branch` successor
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
spv.Branch ^target
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<
|
||||
"Builder *, OperationState *state, Block *successor", [{
|
||||
state->addSuccessor(successor, {});
|
||||
}]
|
||||
>
|
||||
];
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
|
||||
let autogenSerialization = 0;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
|
||||
let summary = [{
|
||||
If Condition is true, branch to true block, otherwise branch to false
|
||||
block.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Condition must be a Boolean type scalar.
|
||||
|
||||
Branch weights are unsigned 32-bit integer literals. There must be
|
||||
either no Branch Weights or exactly two branch weights. If present, the
|
||||
first is the weight for branching to True Label, and the second is the
|
||||
weight for branching to False Label. The implied probability that a
|
||||
branch is taken is its weight divided by the sum of the two Branch
|
||||
weights. At least one weight must be non-zero. A weight of zero does not
|
||||
imply a branch is dead or permit its removal; branch weights are only
|
||||
hints. The two weights must not overflow a 32-bit unsigned integer when
|
||||
added together.
|
||||
|
||||
This instruction must be the last instruction in a block.
|
||||
|
||||
### Custom assembly form
|
||||
|
||||
``` {.ebnf}
|
||||
branch-conditional-op ::= `spv.BranchConditional` ssa-use
|
||||
(`[` integer-literal, integer-literal `]`)?
|
||||
`,` successor `,` successor
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
spv.BranchConditional %condition, ^true_branch, ^false_branch
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_Bool:$condition,
|
||||
OptionalAttr<I32ArrayAttr>:$branch_weights
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<
|
||||
"Builder *, OperationState *state, Value *condition, "
|
||||
"Block *trueBranch, Block *falseBranch, /*optional*/ArrayAttr weights",
|
||||
[{
|
||||
state->addOperands(condition);
|
||||
state->addSuccessor(trueBranch, {});
|
||||
state->addSuccessor(falseBranch, {});
|
||||
state->addAttribute("branch_weights", weights);
|
||||
}]
|
||||
>
|
||||
];
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
|
||||
let autogenSerialization = 0;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Branch indices into the successor list.
|
||||
enum { kTrueIndex = 0, kFalseIndex = 1 };
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> {
|
||||
let summary = "Return with no value from a function with void return type.";
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ using namespace mlir;
|
|||
|
||||
// TODO(antiagainst): generate these strings using ODS.
|
||||
static constexpr const char kAlignmentAttrName[] = "alignment";
|
||||
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
|
||||
static constexpr const char kDefaultValueAttrName[] = "default_value";
|
||||
static constexpr const char kFnNameAttrName[] = "fn";
|
||||
static constexpr const char kIndicesAttrName[] = "indices";
|
||||
|
@ -486,6 +487,119 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.BranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *state) {
|
||||
Block *dest;
|
||||
SmallVector<Value *, 4> destOperands;
|
||||
if (parser->parseSuccessorAndUseList(dest, destOperands))
|
||||
return failure();
|
||||
state->addSuccessor(dest, destOperands);
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::BranchOp branchOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::BranchOp::getOperationName() << ' ';
|
||||
printer->printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0);
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::BranchOp branchOp) {
|
||||
auto *op = branchOp.getOperation();
|
||||
if (op->getNumSuccessors() != 1)
|
||||
branchOp.emitOpError("must have exactly one successor");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.BranchConditionalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseBranchConditionalOp(OpAsmParser *parser,
|
||||
OperationState *state) {
|
||||
auto &builder = parser->getBuilder();
|
||||
OpAsmParser::OperandType condInfo;
|
||||
Block *dest;
|
||||
SmallVector<Value *, 4> destOperands;
|
||||
|
||||
// Parse the condition.
|
||||
Type boolTy = builder.getI1Type();
|
||||
if (parser->parseOperand(condInfo) ||
|
||||
parser->resolveOperand(condInfo, boolTy, state->operands))
|
||||
return failure();
|
||||
|
||||
// Parse the optional branch weights.
|
||||
if (succeeded(parser->parseOptionalLSquare())) {
|
||||
IntegerAttr trueWeight, falseWeight;
|
||||
SmallVector<NamedAttribute, 2> weights;
|
||||
|
||||
auto i32Type = builder.getIntegerType(32);
|
||||
if (parser->parseAttribute(trueWeight, i32Type, "weight", weights) ||
|
||||
parser->parseComma() ||
|
||||
parser->parseAttribute(falseWeight, i32Type, "weight", weights) ||
|
||||
parser->parseRSquare())
|
||||
return failure();
|
||||
|
||||
state->addAttribute(kBranchWeightAttrName,
|
||||
builder.getArrayAttr({trueWeight, falseWeight}));
|
||||
}
|
||||
|
||||
// Parse the true branch.
|
||||
if (parser->parseComma() ||
|
||||
parser->parseSuccessorAndUseList(dest, destOperands))
|
||||
return failure();
|
||||
state->addSuccessor(dest, destOperands);
|
||||
|
||||
// Parse the false branch.
|
||||
destOperands.clear();
|
||||
if (parser->parseComma() ||
|
||||
parser->parseSuccessorAndUseList(dest, destOperands))
|
||||
return failure();
|
||||
state->addSuccessor(dest, destOperands);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::BranchConditionalOp::getOperationName() << ' ';
|
||||
printer->printOperand(branchOp.condition());
|
||||
|
||||
if (auto weights = branchOp.branch_weights()) {
|
||||
*printer << " [";
|
||||
mlir::interleaveComma(
|
||||
weights->getValue(), printer->getStream(),
|
||||
[&](Attribute a) { *printer << a.cast<IntegerAttr>().getInt(); });
|
||||
*printer << "]";
|
||||
}
|
||||
|
||||
*printer << ", ";
|
||||
printer->printSuccessorAndUseList(branchOp.getOperation(),
|
||||
spirv::BranchConditionalOp::kTrueIndex);
|
||||
*printer << ", ";
|
||||
printer->printSuccessorAndUseList(branchOp.getOperation(),
|
||||
spirv::BranchConditionalOp::kFalseIndex);
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
|
||||
auto *op = branchOp.getOperation();
|
||||
if (op->getNumSuccessors() != 2)
|
||||
return branchOp.emitOpError("must have exactly two successors");
|
||||
|
||||
if (auto weights = branchOp.branch_weights()) {
|
||||
if (weights->getValue().size() != 2) {
|
||||
return branchOp.emitOpError("must have exactly two branch weights");
|
||||
}
|
||||
if (llvm::all_of(*weights, [](Attribute attr) {
|
||||
return attr.cast<IntegerAttr>().getValue().isNullValue();
|
||||
}))
|
||||
return branchOp.emitOpError("branch weights cannot both be zero");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.CompositeExtractOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1093,6 +1207,7 @@ static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
|
|||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.Return
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1,5 +1,149 @@
|
|||
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.Branch
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @branch() -> () {
|
||||
// CHECK: spv.Branch ^bb1
|
||||
spv.Branch ^next
|
||||
^next:
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @missing_accessor() -> () {
|
||||
spv.Branch
|
||||
// expected-error @+1 {{expected block name}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @wrong_accessor_count() -> () {
|
||||
%true = spv.constant true
|
||||
// expected-error @+1 {{must have exactly one successor}}
|
||||
"spv.Branch"()[^one, ^two] : () -> ()
|
||||
^one:
|
||||
spv.Return
|
||||
^two:
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @accessor_argument_disallowed() -> () {
|
||||
%zero = spv.constant 0 : i32
|
||||
// expected-error @+1 {{requires zero operands}}
|
||||
"spv.Branch"()[^next(%zero : i32)] : () -> ()
|
||||
^next(%arg: i32):
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.BranchConditional
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @cond_branch() -> () {
|
||||
%true = spv.constant true
|
||||
// CHECK: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
|
||||
spv.BranchConditional %true, ^one, ^two
|
||||
// CHECK: ^bb1
|
||||
^one:
|
||||
spv.Return
|
||||
// CHECK: ^bb2
|
||||
^two:
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cond_branch_with_weights() -> () {
|
||||
%true = spv.constant true
|
||||
// CHECK: spv.BranchConditional %{{.*}} [5, 10]
|
||||
spv.BranchConditional %true [5, 10], ^one, ^two
|
||||
^one:
|
||||
spv.Return
|
||||
^two:
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @missing_condition() -> () {
|
||||
// expected-error @+1 {{expected SSA operand}}
|
||||
spv.BranchConditional ^one, ^two
|
||||
^one:
|
||||
spv.Return
|
||||
^two:
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @wrong_condition_type() -> () {
|
||||
// expected-note @+1 {{prior use here}}
|
||||
%zero = spv.constant 0 : i32
|
||||
// expected-error @+1 {{use of value '%zero' expects different type than prior uses: 'i1' vs 'i32'}}
|
||||
spv.BranchConditional %zero, ^one, ^two
|
||||
^one:
|
||||
spv.Return
|
||||
^two:
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @wrong_accessor_count() -> () {
|
||||
%true = spv.constant true
|
||||
// expected-error @+1 {{must have exactly two successors}}
|
||||
"spv.BranchConditional"(%true)[^one] : (i1) -> ()
|
||||
^one:
|
||||
spv.Return
|
||||
^two:
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @accessor_argment_disallowed() -> () {
|
||||
%true = spv.constant true
|
||||
// expected-error @+1 {{requires a single operand}}
|
||||
"spv.BranchConditional"(%true)[^one(%true : i1), ^two] : (i1) -> ()
|
||||
^one(%arg : i1):
|
||||
spv.Return
|
||||
^two:
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @wrong_number_of_weights() -> () {
|
||||
%true = spv.constant true
|
||||
// expected-error @+1 {{must have exactly two branch weights}}
|
||||
"spv.BranchConditional"(%true)[^one, ^two] {branch_weights = [1 : i32, 2 : i32, 3 : i32]} : (i1) -> ()
|
||||
^one:
|
||||
spv.Return
|
||||
^two:
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @weights_cannot_both_be_zero() -> () {
|
||||
%true = spv.constant true
|
||||
// expected-error @+1 {{branch weights cannot both be zero}}
|
||||
spv.BranchConditional %true [0, 0], ^one, ^two
|
||||
^one:
|
||||
spv.Return
|
||||
^two:
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.Return
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -421,14 +421,14 @@ def get_op_definition(instruction, doc, existing_info):
|
|||
arguments = existing_info.get('arguments', None)
|
||||
if arguments is None:
|
||||
arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
|
||||
arguments = '\n '.join(arguments)
|
||||
arguments = ',\n '.join(arguments)
|
||||
if arguments:
|
||||
# Prepend and append whitespace for formatting
|
||||
arguments = '\n {}\n '.format(arguments)
|
||||
|
||||
assembly = existing_info.get('assembly', None)
|
||||
if assembly is None:
|
||||
assembly = ' ``` {.ebnf}\n'\
|
||||
assembly = '\n ``` {.ebnf}\n'\
|
||||
' [TODO]\n'\
|
||||
' ```\n\n'\
|
||||
' For example:\n\n'\
|
||||
|
|
Loading…
Reference in New Issue