forked from OSchip/llvm-project
[spirv] Add support for spv.selection
Similar to spv.loop, spv.selection is another op for modelling SPIR-V structured control flow. It covers both OpBranchConditional and OpSwitch with OpSelectionMerge. Instead of having a `spv.SelectionMerge` op to directly model selection merge instruction for indicating the merge target, we use regions to delimit the boundary of the selection: the merge target is the next op following the `spv.selection` op. This way it's easier to discover all blocks belonging to the selection and it plays nicer with the MLIR system. PiperOrigin-RevId: 272475006
This commit is contained in:
parent
088f4c502f
commit
f294e0e513
|
@ -212,15 +212,92 @@ control flow construct. With this approach, it's easier to discover all blocks
|
|||
belonging to a structured control flow construct. It is also more idiomatic to
|
||||
MLIR system.
|
||||
|
||||
We introduce a a `spv.loop` op for structured loops. The merge targets are the
|
||||
next ops following them. Inside their regions, a special terminator,
|
||||
`spv._merge` is introduced for branching to the merge target.
|
||||
We introduce a `spv.selection` and `spv.loop` op for structured selections and
|
||||
loops, respectively. The merge targets are the next ops following them. Inside
|
||||
their regions, a special terminator, `spv._merge` is introduced for branching to
|
||||
the merge target.
|
||||
|
||||
### Selection
|
||||
|
||||
`spv.selection` defines a selection construct. It contains one region. The
|
||||
region should contain at least two blocks: one selection header block and one
|
||||
merge block.
|
||||
|
||||
* The selection header block should be the first block. It should contain the
|
||||
`spv.BranchConditional` or `spv.Switch` op.
|
||||
* The merge block should be the last block. The merge block should only
|
||||
contain a `spv._merge` op. Any block can branch to the merge block for early
|
||||
exit.
|
||||
|
||||
```
|
||||
+--------------+
|
||||
| header block | (may have multiple outgoing branches)
|
||||
+--------------+
|
||||
/ | \
|
||||
...
|
||||
|
||||
|
||||
+---------+ +---------+ +---------+
|
||||
| case #0 | | case #1 | | case #2 | ... (may have branches between each other)
|
||||
+---------+ +---------+ +---------+
|
||||
|
||||
|
||||
...
|
||||
\ | /
|
||||
v
|
||||
+-------------+
|
||||
| merge block | (may have multiple incoming branches)
|
||||
+-------------+
|
||||
```
|
||||
|
||||
For example, for the given function
|
||||
|
||||
```c++
|
||||
void loop(bool cond) {
|
||||
int x = 0;
|
||||
if (cond) {
|
||||
x = 1;
|
||||
} else {
|
||||
x = 2;
|
||||
}
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
It will be represented as
|
||||
|
||||
```mlir
|
||||
func @selection(%cond: i1) -> () {
|
||||
%zero = spv.constant 0: i32
|
||||
%one = spv.constant 1: i32
|
||||
%two = spv.constant 2: i32
|
||||
%x = spv.Variable init(%zero) : !spv.ptr<i32, Function>
|
||||
|
||||
spv.selection {
|
||||
spv.BranchConditional %cond, ^then, ^else
|
||||
|
||||
^then:
|
||||
spv.Store "Function" %x, %one : i32
|
||||
spv.Branch ^merge
|
||||
|
||||
^else:
|
||||
spv.Store "Function" %x, %two : i32
|
||||
spv.Branch ^merge
|
||||
|
||||
^merge:
|
||||
spv._merge
|
||||
}
|
||||
|
||||
// ...
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
### Loop
|
||||
|
||||
`spv.loop` defines a loop construct. It contains one region. The `spv.loop`
|
||||
region should contain at least four blocks: one entry block, one loop header
|
||||
block, one loop continue block, one merge block.
|
||||
`spv.loop` defines a loop construct. It contains one region. The region should
|
||||
contain at least four blocks: one entry block, one loop header block, one loop
|
||||
continue block, one merge block.
|
||||
|
||||
* The entry block should be the first block and it should jump to the loop
|
||||
header block, which is the second block.
|
||||
|
|
|
@ -163,6 +163,7 @@ def SPV_OC_OpFUnordGreaterThanEqual : I32EnumAttrCase<"OpFUnordGreaterThanEqual"
|
|||
def SPV_OC_OpControlBarrier : I32EnumAttrCase<"OpControlBarrier", 224>;
|
||||
def SPV_OC_OpMemoryBarrier : I32EnumAttrCase<"OpMemoryBarrier", 225>;
|
||||
def SPV_OC_OpLoopMerge : I32EnumAttrCase<"OpLoopMerge", 246>;
|
||||
def SPV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerge", 247>;
|
||||
def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>;
|
||||
def SPV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>;
|
||||
def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>;
|
||||
|
@ -200,8 +201,9 @@ def SPV_OpcodeAttr :
|
|||
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
|
||||
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
|
||||
SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge,
|
||||
SPV_OC_OpLabel, SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
|
||||
SPV_OC_OpReturnValue, SPV_OC_OpModuleProcessed
|
||||
SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
|
||||
SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
|
||||
SPV_OC_OpModuleProcessed
|
||||
]> {
|
||||
let returnType = "::mlir::spirv::Opcode";
|
||||
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
|
||||
|
@ -1102,6 +1104,19 @@ def SPV_ScopeAttr :
|
|||
let cppNamespace = "::mlir::spirv";
|
||||
}
|
||||
|
||||
def SPV_SC_None : BitEnumAttrCase<"None", 0x0000>;
|
||||
def SPV_SC_Flatten : BitEnumAttrCase<"Flatten", 0x0001>;
|
||||
def SPV_SC_DontFlatten : BitEnumAttrCase<"DontFlatten", 0x0002>;
|
||||
|
||||
def SPV_SelectionControlAttr :
|
||||
BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [
|
||||
SPV_SC_None, SPV_SC_Flatten, SPV_SC_DontFlatten
|
||||
]> {
|
||||
let returnType = "::mlir::spirv::SelectionControl";
|
||||
let convertFromStorage = "static_cast<::mlir::spirv::SelectionControl>($_self.getInt())";
|
||||
let cppNamespace = "::mlir::spirv";
|
||||
}
|
||||
|
||||
def SPV_SC_UniformConstant : I32EnumAttrCase<"UniformConstant", 0>;
|
||||
def SPV_SC_Input : I32EnumAttrCase<"Input", 1>;
|
||||
def SPV_SC_Uniform : I32EnumAttrCase<"Uniform", 2>;
|
||||
|
|
|
@ -256,7 +256,7 @@ def SPV_LoopOp : SPV_Op<"loop"> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_MergeOp : SPV_Op<"_merge", [HasParent<"LoopOp">, Terminator]> {
|
||||
def SPV_MergeOp : SPV_Op<"_merge", [Terminator]> {
|
||||
let summary = "A special terminator for merging a structured selection/loop.";
|
||||
|
||||
let description = [{
|
||||
|
@ -334,4 +334,51 @@ def SPV_ReturnValueOp : SPV_Op<"ReturnValue", [InFunctionScope, Terminator]> {
|
|||
let results = (outs);
|
||||
}
|
||||
|
||||
def SPV_SelectionOp : SPV_Op<"selection"> {
|
||||
let summary = "Define a structured selection.";
|
||||
|
||||
let description = [{
|
||||
SPIR-V can explicitly declare structured control-flow constructs using merge
|
||||
instructions. These explicitly declare a header block before the control
|
||||
flow diverges and a merge block where control flow subsequently converges.
|
||||
These blocks delimit constructs that must nest, and can only be entered
|
||||
and exited in structured ways. See "2.11. Structured Control Flow" of the
|
||||
SPIR-V spec for more details.
|
||||
|
||||
Instead of having a `spv.SelectionMerge` op to directly model selection
|
||||
merge instruction for indicating the merge target, we use regions to delimit
|
||||
the boundary of the selection: the merge target is the next op following the
|
||||
`spv.selection` op. This way it's easier to discover all blocks belonging to
|
||||
the selection and it plays nicer with the MLIR system.
|
||||
|
||||
The `spv.selection` region should contain at least two blocks: one selection
|
||||
header block, and one selection merge. The selection header block should be
|
||||
the first block. The selection merge block should be the last block.
|
||||
The merge block should only contain a `spv._merge` op.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_SelectionControlAttr:$selection_control
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
let regions = (region AnyRegion:$body);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns the selection header block.
|
||||
Block *getHeaderBlock();
|
||||
|
||||
// Returns the selection merge block.
|
||||
Block *getMergeBlock();
|
||||
|
||||
// Adds a selection merge block containing one spv._merge op.
|
||||
void addMergeBlock();
|
||||
}];
|
||||
|
||||
let hasOpcode = 0;
|
||||
|
||||
let autogenSerialization = 0;
|
||||
}
|
||||
|
||||
#endif // SPIRV_CONTROLFLOW_OPS
|
||||
|
|
|
@ -334,6 +334,12 @@ static unsigned getBitWidth(Type type) {
|
|||
llvm_unreachable("unhandled bit width computation for type");
|
||||
}
|
||||
|
||||
/// Returns true if the given `block` only contains one `spv._merge` op.
|
||||
static inline bool isMergeBlock(Block &block) {
|
||||
return !block.empty() && std::next(block.begin()) == block.end() &&
|
||||
isa<spirv::MergeOp>(block.front());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Common parsers and printers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1326,12 +1332,6 @@ static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) {
|
|||
/*printBlockTerminators=*/true);
|
||||
}
|
||||
|
||||
/// Returns true if the given `block` only contains one `spv._merge` op.
|
||||
static inline bool isMergeBlock(Block &block) {
|
||||
return std::next(block.begin()) == block.end() &&
|
||||
isa<spirv::MergeOp>(block.front());
|
||||
}
|
||||
|
||||
/// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
|
||||
/// given `dstBlock`.
|
||||
static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
|
||||
|
@ -1429,16 +1429,19 @@ static LogicalResult verify(spirv::LoopOp loopOp) {
|
|||
}
|
||||
|
||||
Block *spirv::LoopOp::getHeaderBlock() {
|
||||
assert(!body().empty() && "op region should not be empty!");
|
||||
// The second block is the loop header block.
|
||||
return &*std::next(body().begin());
|
||||
}
|
||||
|
||||
Block *spirv::LoopOp::getContinueBlock() {
|
||||
assert(!body().empty() && "op region should not be empty!");
|
||||
// The second to last block is the loop continue block.
|
||||
return &*std::prev(body().end(), 2);
|
||||
}
|
||||
|
||||
Block *spirv::LoopOp::getMergeBlock() {
|
||||
assert(!body().empty() && "op region should not be empty!");
|
||||
// The last block is the loop merge block.
|
||||
return &body().back();
|
||||
}
|
||||
|
@ -1451,7 +1454,7 @@ void spirv::LoopOp::addEntryAndMergeBlock() {
|
|||
OpBuilder builder(mergeBlock);
|
||||
|
||||
// Add a spv._merge op into the merge block.
|
||||
builder.create<spirv::MergeOp>(builder.getUnknownLoc());
|
||||
builder.create<spirv::MergeOp>(getLoc());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1459,10 +1462,16 @@ void spirv::LoopOp::addEntryAndMergeBlock() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(spirv::MergeOp mergeOp) {
|
||||
auto *parentOp = mergeOp.getParentOp();
|
||||
if (!parentOp ||
|
||||
(!isa<spirv::SelectionOp>(parentOp) && !isa<spirv::LoopOp>(parentOp)))
|
||||
return mergeOp.emitOpError(
|
||||
"expected parent op to be 'spv.selection' or 'spv.loop'");
|
||||
|
||||
Block &parentLastBlock = mergeOp.getParentRegion()->back();
|
||||
if (mergeOp.getOperation() != parentLastBlock.getTerminator())
|
||||
return mergeOp.emitOpError(
|
||||
"can only be used in the last block of 'spv.loop'");
|
||||
"can only be used in the last block of 'spv.selection' or 'spv.loop'");
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1807,6 +1816,93 @@ static LogicalResult verify(spirv::SelectOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.selection
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseSelectionOp(OpAsmParser &parser,
|
||||
OperationState &state) {
|
||||
// TODO(antiagainst): support selection control properly
|
||||
Builder builder = parser.getBuilder();
|
||||
state.addAttribute("selection_control",
|
||||
builder.getI32IntegerAttr(
|
||||
static_cast<uint32_t>(spirv::SelectionControl::None)));
|
||||
|
||||
return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
|
||||
/*argTypes=*/{});
|
||||
}
|
||||
|
||||
static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) {
|
||||
auto *op = selectionOp.getOperation();
|
||||
|
||||
printer << spirv::SelectionOp::getOperationName();
|
||||
printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/true);
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::SelectionOp selectionOp) {
|
||||
auto *op = selectionOp.getOperation();
|
||||
|
||||
// We need to verify that the blocks follow the following layout:
|
||||
//
|
||||
// +--------------+
|
||||
// | header block |
|
||||
// +--------------+
|
||||
// / | \
|
||||
// ...
|
||||
//
|
||||
//
|
||||
// +---------+ +---------+ +---------+
|
||||
// | case #0 | | case #1 | | case #2 | ...
|
||||
// +---------+ +---------+ +---------+
|
||||
//
|
||||
//
|
||||
// ...
|
||||
// \ | /
|
||||
// v
|
||||
// +-------------+
|
||||
// | merge block |
|
||||
// +-------------+
|
||||
|
||||
auto ®ion = op->getRegion(0);
|
||||
// Allow empty region as a degenerated case, which can come from
|
||||
// optimizations.
|
||||
if (region.empty())
|
||||
return success();
|
||||
|
||||
// The last block is the merge block.
|
||||
if (!isMergeBlock(region.back()))
|
||||
return selectionOp.emitOpError(
|
||||
"last block must be the merge block with only one 'spv._merge' op");
|
||||
|
||||
if (std::next(region.begin()) == region.end())
|
||||
return selectionOp.emitOpError("must have a selection header block");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
Block *spirv::SelectionOp::getHeaderBlock() {
|
||||
assert(!body().empty() && "op region should not be empty!");
|
||||
// The first block is the loop header block.
|
||||
return &body().front();
|
||||
}
|
||||
|
||||
Block *spirv::SelectionOp::getMergeBlock() {
|
||||
assert(!body().empty() && "op region should not be empty!");
|
||||
// The last block is the loop merge block.
|
||||
return &body().back();
|
||||
}
|
||||
|
||||
void spirv::SelectionOp::addMergeBlock() {
|
||||
assert(body().empty() && "entry and merge block already exist");
|
||||
auto *mergeBlock = new Block();
|
||||
body().push_back(mergeBlock);
|
||||
OpBuilder builder(mergeBlock);
|
||||
|
||||
// Add a spv._merge op into the merge block.
|
||||
builder.create<spirv::MergeOp>(getLoc());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.specConstant
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -241,12 +241,11 @@ private:
|
|||
/// A struct for containing a header block's merge and continue targets.
|
||||
struct BlockMergeInfo {
|
||||
Block *mergeBlock;
|
||||
Block *continueBlock;
|
||||
Block *continueBlock; // nullptr for spv.selection
|
||||
|
||||
BlockMergeInfo() : mergeBlock(nullptr), continueBlock(nullptr) {}
|
||||
BlockMergeInfo(Block *m, Block *c) : mergeBlock(m), continueBlock(c) {}
|
||||
|
||||
operator bool() const { return continueBlock && mergeBlock; }
|
||||
BlockMergeInfo(Block *m, Block *c = nullptr)
|
||||
: mergeBlock(m), continueBlock(c) {}
|
||||
};
|
||||
|
||||
/// Returns the merge and continue target info for the given `block` if it is
|
||||
|
@ -266,6 +265,9 @@ private:
|
|||
/// Processes a SPIR-V OpLabel instruction with the given `operands`.
|
||||
LogicalResult processLabel(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Processes a SPIR-V OpSelectionMerge instruction with the given `operands`.
|
||||
LogicalResult processSelectionMerge(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Processes a SPIR-V OpLoopMerge instruction with the given `operands`.
|
||||
LogicalResult processLoopMerge(ArrayRef<uint32_t> operands);
|
||||
|
||||
|
@ -1485,6 +1487,34 @@ LogicalResult Deserializer::processLabel(ArrayRef<uint32_t> operands) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
|
||||
if (!curBlock) {
|
||||
return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
|
||||
}
|
||||
|
||||
if (operands.size() < 2) {
|
||||
return emitError(
|
||||
unknownLoc,
|
||||
"OpLoopMerge must specify merge target and selection control");
|
||||
}
|
||||
|
||||
if (static_cast<uint32_t>(spirv::LoopControl::None) != operands[1]) {
|
||||
return emitError(unknownLoc,
|
||||
"unimplmented OpSelectionMerge selection control: ")
|
||||
<< operands[2];
|
||||
}
|
||||
|
||||
auto *mergeBlock = getOrCreateBlock(operands[0]);
|
||||
|
||||
if (!blockMergeInfo.try_emplace(curBlock, mergeBlock).second) {
|
||||
return emitError(
|
||||
unknownLoc,
|
||||
"a block cannot have more than one OpSelectionMerge instruction");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
|
||||
if (!curBlock) {
|
||||
return emitError(unknownLoc, "OpLoopMerge must appear in a block");
|
||||
|
@ -1513,8 +1543,9 @@ LogicalResult Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
|
|||
}
|
||||
|
||||
namespace {
|
||||
/// A class for putting all blocks in a structured loop in a spv.loop op.
|
||||
class LoopStructurizer {
|
||||
/// A class for putting all blocks in a structured selection/loop in a
|
||||
/// spv.selection/spv.loop op.
|
||||
class ControlFlowStructurizer {
|
||||
public:
|
||||
/// Structurizes the loop at the given `headerBlock`.
|
||||
///
|
||||
|
@ -1523,15 +1554,19 @@ public:
|
|||
/// the `headerBlock` will be redirected to the `mergeBlock`.
|
||||
static LogicalResult structurize(Location loc, Block *headerBlock,
|
||||
Block *mergeBlock, Block *continueBlock) {
|
||||
return LoopStructurizer(loc, headerBlock, mergeBlock, continueBlock)
|
||||
return ControlFlowStructurizer(loc, headerBlock, mergeBlock, continueBlock)
|
||||
.structurizeImpl();
|
||||
}
|
||||
|
||||
private:
|
||||
LoopStructurizer(Location loc, Block *header, Block *merge, Block *cont)
|
||||
ControlFlowStructurizer(Location loc, Block *header, Block *merge,
|
||||
Block *cont)
|
||||
: location(loc), headerBlock(header), mergeBlock(merge),
|
||||
continueBlock(cont) {}
|
||||
|
||||
/// Creates a new spv.selection op at the beginning of the `mergeBlock`.
|
||||
spirv::SelectionOp createSelectionOp();
|
||||
|
||||
/// Creates a new spv.loop op at the beginning of the `mergeBlock`.
|
||||
spirv::LoopOp createLoopOp();
|
||||
|
||||
|
@ -1545,13 +1580,26 @@ private:
|
|||
|
||||
Block *headerBlock;
|
||||
Block *mergeBlock;
|
||||
Block *continueBlock;
|
||||
Block *continueBlock; // nullptr for spv.selection
|
||||
|
||||
llvm::SetVector<Block *> constructBlocks;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
spirv::LoopOp LoopStructurizer::createLoopOp() {
|
||||
spirv::SelectionOp ControlFlowStructurizer::createSelectionOp() {
|
||||
// Create a builder and set the insertion point to the beginning of the
|
||||
// merge block so that the newly created SelectionOp will be inserted there.
|
||||
OpBuilder builder(&mergeBlock->front());
|
||||
|
||||
auto control = builder.getI32IntegerAttr(
|
||||
static_cast<uint32_t>(spirv::SelectionControl::None));
|
||||
auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
|
||||
selectionOp.addMergeBlock();
|
||||
|
||||
return selectionOp;
|
||||
}
|
||||
|
||||
spirv::LoopOp ControlFlowStructurizer::createLoopOp() {
|
||||
// Create a builder and set the insertion point to the beginning of the
|
||||
// merge block so that the newly created LoopOp will be inserted there.
|
||||
OpBuilder builder(&mergeBlock->front());
|
||||
|
@ -1564,7 +1612,7 @@ spirv::LoopOp LoopStructurizer::createLoopOp() {
|
|||
return loopOp;
|
||||
}
|
||||
|
||||
void LoopStructurizer::collectBlocksInConstruct() {
|
||||
void ControlFlowStructurizer::collectBlocksInConstruct() {
|
||||
assert(constructBlocks.empty() && "expected empty constructBlocks");
|
||||
|
||||
// Put the header block in the work list first.
|
||||
|
@ -1573,35 +1621,45 @@ void LoopStructurizer::collectBlocksInConstruct() {
|
|||
// For each item in the work list, add its successors under conditions.
|
||||
for (unsigned i = 0; i < constructBlocks.size(); ++i) {
|
||||
for (auto *successor : constructBlocks[i]->getSuccessors())
|
||||
if (successor != mergeBlock && successor != continueBlock &&
|
||||
constructBlocks.count(successor) == 0) {
|
||||
if (successor != mergeBlock && successor != continueBlock)
|
||||
constructBlocks.insert(successor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult LoopStructurizer::structurizeImpl() {
|
||||
auto loopOp = createLoopOp();
|
||||
if (!loopOp)
|
||||
LogicalResult ControlFlowStructurizer::structurizeImpl() {
|
||||
Operation *op = nullptr;
|
||||
bool isLoop = continueBlock != nullptr;
|
||||
if (isLoop) {
|
||||
if (auto loopOp = createLoopOp())
|
||||
op = loopOp.getOperation();
|
||||
} else {
|
||||
if (auto selectionOp = createSelectionOp())
|
||||
op = selectionOp.getOperation();
|
||||
}
|
||||
if (!op)
|
||||
return failure();
|
||||
Region &body = op->getRegion(0);
|
||||
|
||||
BlockAndValueMapping mapper;
|
||||
// All references to the old merge block should be directed to the loop
|
||||
// merge block in the LoopOp's region.
|
||||
mapper.map(mergeBlock, &loopOp.body().back());
|
||||
mapper.map(mergeBlock, &body.back());
|
||||
|
||||
collectBlocksInConstruct();
|
||||
// Add the loop continue block at the last so it's the second to last block
|
||||
// in LoopOp's region.
|
||||
constructBlocks.insert(continueBlock);
|
||||
if (isLoop) {
|
||||
// Add the loop continue block at the last so it's the second to last block
|
||||
// in LoopOp's region.
|
||||
constructBlocks.insert(continueBlock);
|
||||
}
|
||||
|
||||
// We've identified all blocks belonging to the loop's region. Now need to
|
||||
// "move" them into the loop. Instead of really moving the blocks, in the
|
||||
// following we copy them and remap all values and branches. This is because:
|
||||
// We've identified all blocks belonging to the selection/loop's region. Now
|
||||
// need to "move" them into the selection/loop. Instead of really moving the
|
||||
// blocks, in the following we copy them and remap all values and branches.
|
||||
// This is because:
|
||||
// * Inserting a block into a region requires the block not in any region
|
||||
// before. But loops can nest so we can create loop ops in a nested manner,
|
||||
// which means some blocks may already be in a loop region when to be moved
|
||||
// again.
|
||||
// before. But selections/loops can nest so we can create selection/loop ops
|
||||
// in a nested manner, which means some blocks may already be in a
|
||||
// selection/loop region when to be moved again.
|
||||
// * It's much trickier to fix up the branches into and out of the loop's
|
||||
// region: we need to treat not-moved blocks and moved blocks differently:
|
||||
// Not-moved blocks jumping to the loop header block need to jump to the
|
||||
|
@ -1611,16 +1669,16 @@ LogicalResult LoopStructurizer::structurizeImpl() {
|
|||
// We cannot use replaceAllUsesWith clearly and it's harder to follow the
|
||||
// logic.
|
||||
|
||||
// Create a corresponding block in the LoopOp's region for each block in
|
||||
// this loop construct.
|
||||
OpBuilder loopBuilder(loopOp.body());
|
||||
// Create a corresponding block in the SelectionOp/LoopOp's region for each
|
||||
// block in this loop construct.
|
||||
OpBuilder builder(body);
|
||||
for (auto *block : constructBlocks) {
|
||||
assert(block->getNumArguments() == 0 &&
|
||||
"block in loop construct should not have arguments");
|
||||
|
||||
// Create an block and insert it before the loop merge block in the
|
||||
// LoopOp's region.
|
||||
auto *newBlock = loopBuilder.createBlock(&loopOp.body().back());
|
||||
auto *newBlock = builder.createBlock(&body.back());
|
||||
mapper.map(block, newBlock);
|
||||
|
||||
for (auto &op : *block)
|
||||
|
@ -1636,30 +1694,30 @@ LogicalResult LoopStructurizer::structurizeImpl() {
|
|||
if (auto *mappedOp = mapper.lookupOrNull(succOp.get()))
|
||||
succOp.set(mappedOp);
|
||||
};
|
||||
for (auto &block : loopOp.body()) {
|
||||
for (auto &block : body) {
|
||||
block.walk(remapOperands);
|
||||
}
|
||||
|
||||
// We have created the LoopOp and "moved" all blocks belonging to the loop
|
||||
// construct into its region. Next we need to fix the connections between
|
||||
// this new LoopOp with existing blocks.
|
||||
// We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
|
||||
// the selection/loop construct into its region. Next we need to fix the
|
||||
// connections between this new SelectionOp/LoopOp with existing blocks.
|
||||
|
||||
// All existing incoming branches should go to the merge block, where the
|
||||
// LoopOp resides right now.
|
||||
// SelectionOp/LoopOp resides right now.
|
||||
headerBlock->replaceAllUsesWith(mergeBlock);
|
||||
|
||||
// The loop entry block should have a unconditional branch jumping to the
|
||||
// loop header block.
|
||||
loopBuilder.setInsertionPointToEnd(&loopOp.body().front());
|
||||
loopBuilder.create<spirv::BranchOp>(location,
|
||||
mapper.lookupOrNull(headerBlock));
|
||||
|
||||
// All the blocks cloned into the LoopOp's region can now be deleted.
|
||||
for (auto *block : constructBlocks) {
|
||||
block->clear();
|
||||
block->erase();
|
||||
if (isLoop) {
|
||||
// The loop entry block should have a unconditional branch jumping to the
|
||||
// loop header block.
|
||||
builder.setInsertionPointToEnd(&body.front());
|
||||
builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock));
|
||||
}
|
||||
|
||||
// All the blocks cloned into the SelectionOp/LoopOp's region can now be
|
||||
// deleted.
|
||||
for (auto *block : constructBlocks)
|
||||
block->erase();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1668,23 +1726,21 @@ LogicalResult Deserializer::structurizeControlFlow() {
|
|||
|
||||
while (!blockMergeInfo.empty()) {
|
||||
auto *headerBlock = blockMergeInfo.begin()->first;
|
||||
const auto &mergeInfo = blockMergeInfo.begin()->second;
|
||||
LLVM_DEBUG(llvm::dbgs() << "[cf] header block @ " << headerBlock << "\n");
|
||||
|
||||
const auto &mergeInfo = blockMergeInfo.begin()->second;
|
||||
auto *mergeBlock = mergeInfo.mergeBlock;
|
||||
auto *continueBlock = mergeInfo.continueBlock;
|
||||
LLVM_DEBUG(llvm::dbgs() << "[cf] header block @ " << headerBlock << "\n");
|
||||
assert(mergeBlock && "merge block cannot be nullptr");
|
||||
LLVM_DEBUG(llvm::dbgs() << "[cf] merge block @ " << mergeBlock << "\n");
|
||||
if (!continueBlock) {
|
||||
return emitError(unknownLoc, "structurizing selection unimplemented");
|
||||
if (continueBlock) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "[cf] continue block @ " << continueBlock << "\n");
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "[cf] continue block @ " << continueBlock << "\n");
|
||||
|
||||
if (failed(LoopStructurizer::structurize(unknownLoc, headerBlock,
|
||||
mergeBlock, continueBlock))) {
|
||||
if (failed(ControlFlowStructurizer::structurize(unknownLoc, headerBlock,
|
||||
mergeBlock, continueBlock)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
blockMergeInfo.erase(headerBlock);
|
||||
}
|
||||
|
@ -1830,6 +1886,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
|
|||
return processBranch(operands);
|
||||
case spirv::Opcode::OpBranchConditional:
|
||||
return processBranchConditional(operands);
|
||||
case spirv::Opcode::OpSelectionMerge:
|
||||
return processSelectionMerge(operands);
|
||||
case spirv::Opcode::OpLoopMerge:
|
||||
return processLoopMerge(operands);
|
||||
default:
|
||||
|
|
|
@ -250,6 +250,8 @@ private:
|
|||
processBlock(Block *block,
|
||||
llvm::function_ref<void()> actionBeforeTerminator = nullptr);
|
||||
|
||||
LogicalResult processSelectionOp(spirv::SelectionOp selectionOp);
|
||||
|
||||
LogicalResult processLoopOp(spirv::LoopOp loopOp);
|
||||
|
||||
LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);
|
||||
|
@ -1220,10 +1222,18 @@ Serializer::processBlock(Block *block,
|
|||
namespace {
|
||||
/// A pre-order depth-first vistor for processing basic blocks in a spv.loop op.
|
||||
///
|
||||
/// This visitor is special tailored for spv.loop block serialization to satisfy
|
||||
/// SPIR-V validation rules. It should not be used as a general depth-first
|
||||
/// block visitor.
|
||||
class LoopBlockVisitor {
|
||||
/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
|
||||
/// of blocks in a function must satisfy the rule that blocks appear before all
|
||||
/// blocks they dominate." This can be achieved by a pre-order CFG traversal
|
||||
/// algorithm. To make the serialization output more logical and readable to
|
||||
/// human, we perform depth-first CFG traversal and delay the serialization of
|
||||
/// the merge block (and the continue block) until after all other blocks have
|
||||
/// been processed.
|
||||
///
|
||||
/// This visitor is special tailored for spv.selection or spv.loop block
|
||||
/// serialization to satisfy SPIR-V validation rules. It should not be used
|
||||
/// as a general depth-first block visitor.
|
||||
class ControlFlowBlockVisitor {
|
||||
public:
|
||||
using BlockHandlerType = llvm::function_ref<LogicalResult(Block *)>;
|
||||
|
||||
|
@ -1232,12 +1242,13 @@ public:
|
|||
/// Skips handling the `headerBlock` and blocks in the `skipBlocks` list.
|
||||
static LogicalResult visit(Block *headerBlock, BlockHandlerType blockHandler,
|
||||
ArrayRef<Block *> skipBlocks) {
|
||||
return LoopBlockVisitor(blockHandler, skipBlocks)
|
||||
return ControlFlowBlockVisitor(blockHandler, skipBlocks)
|
||||
.visitHeaderBlock(headerBlock);
|
||||
}
|
||||
|
||||
private:
|
||||
LoopBlockVisitor(BlockHandlerType blockHandler, ArrayRef<Block *> skipBlocks)
|
||||
ControlFlowBlockVisitor(BlockHandlerType blockHandler,
|
||||
ArrayRef<Block *> skipBlocks)
|
||||
: blockHandler(blockHandler),
|
||||
doneBlocks(skipBlocks.begin(), skipBlocks.end()) {}
|
||||
|
||||
|
@ -1274,16 +1285,54 @@ private:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
|
||||
// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
|
||||
// of blocks in a function must satisfy the rule that blocks appear before all
|
||||
// blocks they dominate." This can be achieved by a pre-order CFG traversal
|
||||
// algorithm. To make the serialization output more logical and readable to
|
||||
// human, we perform depth-first CFG traversal and delay the serialization of
|
||||
// the continue block and the merge block until after all other blocks have
|
||||
// been processed.
|
||||
LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
|
||||
// Assign <id>s to all blocks so that branches inside the SelectionOp can
|
||||
// resolve properly.
|
||||
auto &body = selectionOp.body();
|
||||
for (Block &block : body)
|
||||
assignBlockID(&block);
|
||||
|
||||
// Assign <id>s to all blocks so that branchs inside the LoopOp can resolve
|
||||
auto *headerBlock = selectionOp.getHeaderBlock();
|
||||
auto *mergeBlock = selectionOp.getMergeBlock();
|
||||
auto headerID = findBlockID(headerBlock);
|
||||
auto mergeID = findBlockID(mergeBlock);
|
||||
|
||||
// This selection is in some MLIR block with preceding and following ops. In
|
||||
// the binary format, it should reside in separate SPIR-V blocks from its
|
||||
// preceding and following ops. So we need to emit unconditional branches to
|
||||
// jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
|
||||
// afterwards.
|
||||
encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID});
|
||||
|
||||
// Emit the selection header block, which dominates all other blocks, first.
|
||||
// We need to emit an OpSelectionMerge instruction before the loop header
|
||||
// block's terminator.
|
||||
auto emitSelectionMerge = [&]() {
|
||||
// TODO(antiagainst): properly support loop control here
|
||||
encodeInstructionInto(
|
||||
functions, spirv::Opcode::OpSelectionMerge,
|
||||
{mergeID, static_cast<uint32_t>(spirv::LoopControl::None)});
|
||||
};
|
||||
if (failed(processBlock(headerBlock, emitSelectionMerge)))
|
||||
return failure();
|
||||
|
||||
// Process all blocks with a depth-first visitor starting from the header
|
||||
// block. The selection header block and merge block are skipped by this
|
||||
// visitor.
|
||||
auto handleBlock = [&](Block *block) { return processBlock(block); };
|
||||
if (failed(ControlFlowBlockVisitor::visit(headerBlock, handleBlock,
|
||||
{mergeBlock})))
|
||||
return failure();
|
||||
|
||||
// There is nothing to do for the merge block in the selection, which just
|
||||
// contains a spv._merge op, itself. But we need to have an OpLabel
|
||||
// instruction to start a new SPIR-V block for ops following this SelectionOp.
|
||||
// The block should use the <id> for the merge block.
|
||||
return encodeInstructionInto(functions, spirv::Opcode::OpLabel, {mergeID});
|
||||
}
|
||||
|
||||
LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
|
||||
// Assign <id>s to all blocks so that branches inside the LoopOp can resolve
|
||||
// properly. We don't need to assign for the entry block, which is just for
|
||||
// satisfying MLIR region's structural requirement.
|
||||
auto &body = loopOp.body();
|
||||
|
@ -1303,7 +1352,6 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
|
|||
// preceding and following ops. So we need to emit unconditional branches to
|
||||
// jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
|
||||
// afterwards.
|
||||
|
||||
encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID});
|
||||
|
||||
// Emit the loop header block, which dominates all other blocks, first. We
|
||||
|
@ -1322,8 +1370,8 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
|
|||
// block. The loop header block, loop continue block, and loop merge block are
|
||||
// skipped by this visitor and handled later in this function.
|
||||
auto handleBlock = [&](Block *block) { return processBlock(block); };
|
||||
if (failed(LoopBlockVisitor::visit(headerBlock, handleBlock,
|
||||
{continueBlock, mergeBlock})))
|
||||
if (failed(ControlFlowBlockVisitor::visit(headerBlock, handleBlock,
|
||||
{continueBlock, mergeBlock})))
|
||||
return failure();
|
||||
|
||||
// We have handled all other blocks. Now get to the loop continue block.
|
||||
|
@ -1332,7 +1380,8 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
|
|||
|
||||
// There is nothing to do for the merge block in the loop, which just contains
|
||||
// a spv._merge op, itself. But we need to have an OpLabel instruction to
|
||||
// start a new SPIR-V block for ops following this LoopOp.
|
||||
// start a new SPIR-V block for ops following this LoopOp. The block should
|
||||
// use the <id> for the merge block.
|
||||
return encodeInstructionInto(functions, spirv::Opcode::OpLabel, {mergeID});
|
||||
}
|
||||
|
||||
|
@ -1438,6 +1487,9 @@ LogicalResult Serializer::processOperation(Operation *op) {
|
|||
if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
|
||||
return processGlobalVariableOp(varOp);
|
||||
}
|
||||
if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) {
|
||||
return processSelectionOp(selectionOp);
|
||||
}
|
||||
if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) {
|
||||
return processLoopOp(loopOp);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @selection(%cond: i1) -> () {
|
||||
%zero = spv.constant 0: i32
|
||||
%one = spv.constant 1: i32
|
||||
%two = spv.constant 2: i32
|
||||
%var = spv.Variable init(%zero) : !spv.ptr<i32, Function>
|
||||
|
||||
// CHECK: spv.Branch ^bb1
|
||||
// CHECK-NEXT: ^bb1:
|
||||
// CHECK-NEXT: spv.selection
|
||||
spv.selection {
|
||||
// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
|
||||
spv.BranchConditional %cond, ^then, ^else
|
||||
|
||||
// CHECK-NEXT: ^bb1:
|
||||
^then:
|
||||
// CHECK-NEXT: spv.constant 1
|
||||
// CHECK-NEXT: spv.Store
|
||||
spv.Store "Function" %var, %one : i32
|
||||
// CHECK-NEXT: spv.Branch ^bb3
|
||||
spv.Branch ^merge
|
||||
|
||||
// CHECK-NEXT: ^bb2:
|
||||
^else:
|
||||
// CHECK-NEXT: spv.constant 2
|
||||
// CHECK-NEXT: spv.Store
|
||||
spv.Store "Function" %var, %two : i32
|
||||
// CHECK-NEXT: spv.Branch ^bb3
|
||||
spv.Branch ^merge
|
||||
|
||||
// CHECK-NEXT: ^bb3:
|
||||
^merge:
|
||||
// CHECK-NEXT: spv._merge
|
||||
spv._merge
|
||||
}
|
||||
|
||||
spv.Return
|
||||
}
|
||||
|
||||
func @main() -> () {
|
||||
spv.Return
|
||||
}
|
||||
spv.EntryPoint "GLCompute" @main
|
||||
spv.ExecutionMode @main "LocalSize", 1, 1, 1
|
||||
} attributes {
|
||||
capabilities = ["Shader"]
|
||||
}
|
|
@ -404,16 +404,38 @@ func @only_entry_and_continue_branch_to_header() -> () {
|
|||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.merge
|
||||
// spv._merge
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @merge() -> () {
|
||||
// expected-error @+1 {{expects parent op 'spv.loop'}}
|
||||
// expected-error @+1 {{expected parent op to be 'spv.selection' or 'spv.loop'}}
|
||||
spv._merge
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @only_allowed_in_last_block(%cond : i1) -> () {
|
||||
%zero = spv.constant 0: i32
|
||||
%one = spv.constant 1: i32
|
||||
%var = spv.Variable init(%zero) : !spv.ptr<i32, Function>
|
||||
|
||||
spv.selection {
|
||||
spv.BranchConditional %cond, ^then, ^merge
|
||||
|
||||
^then:
|
||||
spv.Store "Function" %var, %one : i32
|
||||
// expected-error @+1 {{can only be used in the last block of 'spv.selection' or 'spv.loop'}}
|
||||
spv._merge
|
||||
|
||||
^merge:
|
||||
spv._merge
|
||||
}
|
||||
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @only_allowed_in_last_block() -> () {
|
||||
%true = spv.constant true
|
||||
spv.loop {
|
||||
|
@ -421,7 +443,7 @@ func @only_allowed_in_last_block() -> () {
|
|||
^header:
|
||||
spv.BranchConditional %true, ^body, ^merge
|
||||
^body:
|
||||
// expected-error @+1 {{can only be used in the last block of 'spv.loop'}}
|
||||
// expected-error @+1 {{can only be used in the last block of 'spv.selection' or 'spv.loop'}}
|
||||
spv._merge
|
||||
^continue:
|
||||
spv.Branch ^header
|
||||
|
@ -487,3 +509,98 @@ func @value_type_mismatch() -> (f32) {
|
|||
// expected-error @+1 {{return value's type ('i32') mismatch with function's result type ('f32')}}
|
||||
spv.ReturnValue %0 : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.selection
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @selection(%cond: i1) -> () {
|
||||
%zero = spv.constant 0: i32
|
||||
%one = spv.constant 1: i32
|
||||
%var = spv.Variable init(%zero) : !spv.ptr<i32, Function>
|
||||
|
||||
// CHECK: spv.selection {
|
||||
spv.selection {
|
||||
// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
|
||||
spv.BranchConditional %cond, ^then, ^merge
|
||||
|
||||
// CHECK: ^bb1
|
||||
^then:
|
||||
spv.Store "Function" %var, %one : i32
|
||||
// CHECK: spv.Branch ^bb2
|
||||
spv.Branch ^merge
|
||||
|
||||
// CHECK: ^bb2
|
||||
^merge:
|
||||
// CHECK-NEXT: spv._merge
|
||||
spv._merge
|
||||
}
|
||||
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @selection(%cond: i1) -> () {
|
||||
%zero = spv.constant 0: i32
|
||||
%one = spv.constant 1: i32
|
||||
%two = spv.constant 2: i32
|
||||
%var = spv.Variable init(%zero) : !spv.ptr<i32, Function>
|
||||
|
||||
// CHECK: spv.selection {
|
||||
spv.selection {
|
||||
// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
|
||||
spv.BranchConditional %cond, ^then, ^else
|
||||
|
||||
// CHECK: ^bb1
|
||||
^then:
|
||||
spv.Store "Function" %var, %one : i32
|
||||
// CHECK: spv.Branch ^bb3
|
||||
spv.Branch ^merge
|
||||
|
||||
// CHECK: ^bb2
|
||||
^else:
|
||||
spv.Store "Function" %var, %two : i32
|
||||
// CHECK: spv.Branch ^bb3
|
||||
spv.Branch ^merge
|
||||
|
||||
// CHECK: ^bb3
|
||||
^merge:
|
||||
// CHECK-NEXT: spv._merge
|
||||
spv._merge
|
||||
}
|
||||
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @empty_region
|
||||
func @empty_region() -> () {
|
||||
// CHECK: spv.selection
|
||||
spv.selection {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @wrong_merge_block() -> () {
|
||||
// expected-error @+1 {{last block must be the merge block with only one 'spv._merge' op}}
|
||||
spv.selection {
|
||||
spv.Return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @missing_entry_block() -> () {
|
||||
// expected-error @+1 {{must have a selection header block}}
|
||||
spv.selection {
|
||||
spv._merge
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue