[spirv] Add support for spv.selection

Similar to spv.loop, spv.selection is another op for modelling
SPIR-V structured control flow. It covers both OpBranchConditional
and OpSwitch with OpSelectionMerge.

Instead of having a `spv.SelectionMerge` op to directly model
selection merge instruction for indicating the merge target,
we use regions to delimit the boundary of the selection: the
merge target is the next op following the `spv.selection` op.
This way it's easier to discover all blocks belonging to
the selection and it plays nicer with the MLIR system.

PiperOrigin-RevId: 272475006
This commit is contained in:
Lei Zhang 2019-10-02 11:00:50 -07:00 committed by A. Unique TensorFlower
parent 088f4c502f
commit f294e0e513
8 changed files with 605 additions and 94 deletions

View File

@ -212,15 +212,92 @@ control flow construct. With this approach, it's easier to discover all blocks
belonging to a structured control flow construct. It is also more idiomatic to
MLIR system.
We introduce a a `spv.loop` op for structured loops. The merge targets are the
next ops following them. Inside their regions, a special terminator,
`spv._merge` is introduced for branching to the merge target.
We introduce a `spv.selection` and `spv.loop` op for structured selections and
loops, respectively. The merge targets are the next ops following them. Inside
their regions, a special terminator, `spv._merge` is introduced for branching to
the merge target.
### Selection
`spv.selection` defines a selection construct. It contains one region. The
region should contain at least two blocks: one selection header block and one
merge block.
* The selection header block should be the first block. It should contain the
`spv.BranchConditional` or `spv.Switch` op.
* The merge block should be the last block. The merge block should only
contain a `spv._merge` op. Any block can branch to the merge block for early
exit.
```
+--------------+
| header block | (may have multiple outgoing branches)
+--------------+
/ | \
...
+---------+ +---------+ +---------+
| case #0 | | case #1 | | case #2 | ... (may have branches between each other)
+---------+ +---------+ +---------+
...
\ | /
v
+-------------+
| merge block | (may have multiple incoming branches)
+-------------+
```
For example, for the given function
```c++
void loop(bool cond) {
int x = 0;
if (cond) {
x = 1;
} else {
x = 2;
}
// ...
}
```
It will be represented as
```mlir
func @selection(%cond: i1) -> () {
%zero = spv.constant 0: i32
%one = spv.constant 1: i32
%two = spv.constant 2: i32
%x = spv.Variable init(%zero) : !spv.ptr<i32, Function>
spv.selection {
spv.BranchConditional %cond, ^then, ^else
^then:
spv.Store "Function" %x, %one : i32
spv.Branch ^merge
^else:
spv.Store "Function" %x, %two : i32
spv.Branch ^merge
^merge:
spv._merge
}
// ...
}
```
### Loop
`spv.loop` defines a loop construct. It contains one region. The `spv.loop`
region should contain at least four blocks: one entry block, one loop header
block, one loop continue block, one merge block.
`spv.loop` defines a loop construct. It contains one region. The region should
contain at least four blocks: one entry block, one loop header block, one loop
continue block, one merge block.
* The entry block should be the first block and it should jump to the loop
header block, which is the second block.

View File

@ -163,6 +163,7 @@ def SPV_OC_OpFUnordGreaterThanEqual : I32EnumAttrCase<"OpFUnordGreaterThanEqual"
def SPV_OC_OpControlBarrier : I32EnumAttrCase<"OpControlBarrier", 224>;
def SPV_OC_OpMemoryBarrier : I32EnumAttrCase<"OpMemoryBarrier", 225>;
def SPV_OC_OpLoopMerge : I32EnumAttrCase<"OpLoopMerge", 246>;
def SPV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerge", 247>;
def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>;
def SPV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>;
def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>;
@ -200,8 +201,9 @@ def SPV_OpcodeAttr :
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge,
SPV_OC_OpLabel, SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
SPV_OC_OpReturnValue, SPV_OC_OpModuleProcessed
SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
SPV_OC_OpModuleProcessed
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
@ -1102,6 +1104,19 @@ def SPV_ScopeAttr :
let cppNamespace = "::mlir::spirv";
}
def SPV_SC_None : BitEnumAttrCase<"None", 0x0000>;
def SPV_SC_Flatten : BitEnumAttrCase<"Flatten", 0x0001>;
def SPV_SC_DontFlatten : BitEnumAttrCase<"DontFlatten", 0x0002>;
def SPV_SelectionControlAttr :
BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [
SPV_SC_None, SPV_SC_Flatten, SPV_SC_DontFlatten
]> {
let returnType = "::mlir::spirv::SelectionControl";
let convertFromStorage = "static_cast<::mlir::spirv::SelectionControl>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_SC_UniformConstant : I32EnumAttrCase<"UniformConstant", 0>;
def SPV_SC_Input : I32EnumAttrCase<"Input", 1>;
def SPV_SC_Uniform : I32EnumAttrCase<"Uniform", 2>;

View File

@ -256,7 +256,7 @@ def SPV_LoopOp : SPV_Op<"loop"> {
// -----
def SPV_MergeOp : SPV_Op<"_merge", [HasParent<"LoopOp">, Terminator]> {
def SPV_MergeOp : SPV_Op<"_merge", [Terminator]> {
let summary = "A special terminator for merging a structured selection/loop.";
let description = [{
@ -334,4 +334,51 @@ def SPV_ReturnValueOp : SPV_Op<"ReturnValue", [InFunctionScope, Terminator]> {
let results = (outs);
}
def SPV_SelectionOp : SPV_Op<"selection"> {
let summary = "Define a structured selection.";
let description = [{
SPIR-V can explicitly declare structured control-flow constructs using merge
instructions. These explicitly declare a header block before the control
flow diverges and a merge block where control flow subsequently converges.
These blocks delimit constructs that must nest, and can only be entered
and exited in structured ways. See "2.11. Structured Control Flow" of the
SPIR-V spec for more details.
Instead of having a `spv.SelectionMerge` op to directly model selection
merge instruction for indicating the merge target, we use regions to delimit
the boundary of the selection: the merge target is the next op following the
`spv.selection` op. This way it's easier to discover all blocks belonging to
the selection and it plays nicer with the MLIR system.
The `spv.selection` region should contain at least two blocks: one selection
header block, and one selection merge. The selection header block should be
the first block. The selection merge block should be the last block.
The merge block should only contain a `spv._merge` op.
}];
let arguments = (ins
SPV_SelectionControlAttr:$selection_control
);
let results = (outs);
let regions = (region AnyRegion:$body);
let extraClassDeclaration = [{
// Returns the selection header block.
Block *getHeaderBlock();
// Returns the selection merge block.
Block *getMergeBlock();
// Adds a selection merge block containing one spv._merge op.
void addMergeBlock();
}];
let hasOpcode = 0;
let autogenSerialization = 0;
}
#endif // SPIRV_CONTROLFLOW_OPS

View File

@ -334,6 +334,12 @@ static unsigned getBitWidth(Type type) {
llvm_unreachable("unhandled bit width computation for type");
}
/// Returns true if the given `block` only contains one `spv._merge` op.
static inline bool isMergeBlock(Block &block) {
return !block.empty() && std::next(block.begin()) == block.end() &&
isa<spirv::MergeOp>(block.front());
}
//===----------------------------------------------------------------------===//
// Common parsers and printers
//===----------------------------------------------------------------------===//
@ -1326,12 +1332,6 @@ static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) {
/*printBlockTerminators=*/true);
}
/// Returns true if the given `block` only contains one `spv._merge` op.
static inline bool isMergeBlock(Block &block) {
return std::next(block.begin()) == block.end() &&
isa<spirv::MergeOp>(block.front());
}
/// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
/// given `dstBlock`.
static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
@ -1429,16 +1429,19 @@ static LogicalResult verify(spirv::LoopOp loopOp) {
}
Block *spirv::LoopOp::getHeaderBlock() {
assert(!body().empty() && "op region should not be empty!");
// The second block is the loop header block.
return &*std::next(body().begin());
}
Block *spirv::LoopOp::getContinueBlock() {
assert(!body().empty() && "op region should not be empty!");
// The second to last block is the loop continue block.
return &*std::prev(body().end(), 2);
}
Block *spirv::LoopOp::getMergeBlock() {
assert(!body().empty() && "op region should not be empty!");
// The last block is the loop merge block.
return &body().back();
}
@ -1451,7 +1454,7 @@ void spirv::LoopOp::addEntryAndMergeBlock() {
OpBuilder builder(mergeBlock);
// Add a spv._merge op into the merge block.
builder.create<spirv::MergeOp>(builder.getUnknownLoc());
builder.create<spirv::MergeOp>(getLoc());
}
//===----------------------------------------------------------------------===//
@ -1459,10 +1462,16 @@ void spirv::LoopOp::addEntryAndMergeBlock() {
//===----------------------------------------------------------------------===//
static LogicalResult verify(spirv::MergeOp mergeOp) {
auto *parentOp = mergeOp.getParentOp();
if (!parentOp ||
(!isa<spirv::SelectionOp>(parentOp) && !isa<spirv::LoopOp>(parentOp)))
return mergeOp.emitOpError(
"expected parent op to be 'spv.selection' or 'spv.loop'");
Block &parentLastBlock = mergeOp.getParentRegion()->back();
if (mergeOp.getOperation() != parentLastBlock.getTerminator())
return mergeOp.emitOpError(
"can only be used in the last block of 'spv.loop'");
"can only be used in the last block of 'spv.selection' or 'spv.loop'");
return success();
}
@ -1807,6 +1816,93 @@ static LogicalResult verify(spirv::SelectOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// spv.selection
//===----------------------------------------------------------------------===//
static ParseResult parseSelectionOp(OpAsmParser &parser,
OperationState &state) {
// TODO(antiagainst): support selection control properly
Builder builder = parser.getBuilder();
state.addAttribute("selection_control",
builder.getI32IntegerAttr(
static_cast<uint32_t>(spirv::SelectionControl::None)));
return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
/*argTypes=*/{});
}
static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) {
auto *op = selectionOp.getOperation();
printer << spirv::SelectionOp::getOperationName();
printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
}
static LogicalResult verify(spirv::SelectionOp selectionOp) {
auto *op = selectionOp.getOperation();
// We need to verify that the blocks follow the following layout:
//
// +--------------+
// | header block |
// +--------------+
// / | \
// ...
//
//
// +---------+ +---------+ +---------+
// | case #0 | | case #1 | | case #2 | ...
// +---------+ +---------+ +---------+
//
//
// ...
// \ | /
// v
// +-------------+
// | merge block |
// +-------------+
auto &region = op->getRegion(0);
// Allow empty region as a degenerated case, which can come from
// optimizations.
if (region.empty())
return success();
// The last block is the merge block.
if (!isMergeBlock(region.back()))
return selectionOp.emitOpError(
"last block must be the merge block with only one 'spv._merge' op");
if (std::next(region.begin()) == region.end())
return selectionOp.emitOpError("must have a selection header block");
return success();
}
Block *spirv::SelectionOp::getHeaderBlock() {
assert(!body().empty() && "op region should not be empty!");
// The first block is the loop header block.
return &body().front();
}
Block *spirv::SelectionOp::getMergeBlock() {
assert(!body().empty() && "op region should not be empty!");
// The last block is the loop merge block.
return &body().back();
}
void spirv::SelectionOp::addMergeBlock() {
assert(body().empty() && "entry and merge block already exist");
auto *mergeBlock = new Block();
body().push_back(mergeBlock);
OpBuilder builder(mergeBlock);
// Add a spv._merge op into the merge block.
builder.create<spirv::MergeOp>(getLoc());
}
//===----------------------------------------------------------------------===//
// spv.specConstant
//===----------------------------------------------------------------------===//

View File

@ -241,12 +241,11 @@ private:
/// A struct for containing a header block's merge and continue targets.
struct BlockMergeInfo {
Block *mergeBlock;
Block *continueBlock;
Block *continueBlock; // nullptr for spv.selection
BlockMergeInfo() : mergeBlock(nullptr), continueBlock(nullptr) {}
BlockMergeInfo(Block *m, Block *c) : mergeBlock(m), continueBlock(c) {}
operator bool() const { return continueBlock && mergeBlock; }
BlockMergeInfo(Block *m, Block *c = nullptr)
: mergeBlock(m), continueBlock(c) {}
};
/// Returns the merge and continue target info for the given `block` if it is
@ -266,6 +265,9 @@ private:
/// Processes a SPIR-V OpLabel instruction with the given `operands`.
LogicalResult processLabel(ArrayRef<uint32_t> operands);
/// Processes a SPIR-V OpSelectionMerge instruction with the given `operands`.
LogicalResult processSelectionMerge(ArrayRef<uint32_t> operands);
/// Processes a SPIR-V OpLoopMerge instruction with the given `operands`.
LogicalResult processLoopMerge(ArrayRef<uint32_t> operands);
@ -1485,6 +1487,34 @@ LogicalResult Deserializer::processLabel(ArrayRef<uint32_t> operands) {
return success();
}
LogicalResult Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
if (!curBlock) {
return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
}
if (operands.size() < 2) {
return emitError(
unknownLoc,
"OpLoopMerge must specify merge target and selection control");
}
if (static_cast<uint32_t>(spirv::LoopControl::None) != operands[1]) {
return emitError(unknownLoc,
"unimplmented OpSelectionMerge selection control: ")
<< operands[2];
}
auto *mergeBlock = getOrCreateBlock(operands[0]);
if (!blockMergeInfo.try_emplace(curBlock, mergeBlock).second) {
return emitError(
unknownLoc,
"a block cannot have more than one OpSelectionMerge instruction");
}
return success();
}
LogicalResult Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
if (!curBlock) {
return emitError(unknownLoc, "OpLoopMerge must appear in a block");
@ -1513,8 +1543,9 @@ LogicalResult Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
}
namespace {
/// A class for putting all blocks in a structured loop in a spv.loop op.
class LoopStructurizer {
/// A class for putting all blocks in a structured selection/loop in a
/// spv.selection/spv.loop op.
class ControlFlowStructurizer {
public:
/// Structurizes the loop at the given `headerBlock`.
///
@ -1523,15 +1554,19 @@ public:
/// the `headerBlock` will be redirected to the `mergeBlock`.
static LogicalResult structurize(Location loc, Block *headerBlock,
Block *mergeBlock, Block *continueBlock) {
return LoopStructurizer(loc, headerBlock, mergeBlock, continueBlock)
return ControlFlowStructurizer(loc, headerBlock, mergeBlock, continueBlock)
.structurizeImpl();
}
private:
LoopStructurizer(Location loc, Block *header, Block *merge, Block *cont)
ControlFlowStructurizer(Location loc, Block *header, Block *merge,
Block *cont)
: location(loc), headerBlock(header), mergeBlock(merge),
continueBlock(cont) {}
/// Creates a new spv.selection op at the beginning of the `mergeBlock`.
spirv::SelectionOp createSelectionOp();
/// Creates a new spv.loop op at the beginning of the `mergeBlock`.
spirv::LoopOp createLoopOp();
@ -1545,13 +1580,26 @@ private:
Block *headerBlock;
Block *mergeBlock;
Block *continueBlock;
Block *continueBlock; // nullptr for spv.selection
llvm::SetVector<Block *> constructBlocks;
};
} // namespace
spirv::LoopOp LoopStructurizer::createLoopOp() {
spirv::SelectionOp ControlFlowStructurizer::createSelectionOp() {
// Create a builder and set the insertion point to the beginning of the
// merge block so that the newly created SelectionOp will be inserted there.
OpBuilder builder(&mergeBlock->front());
auto control = builder.getI32IntegerAttr(
static_cast<uint32_t>(spirv::SelectionControl::None));
auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
selectionOp.addMergeBlock();
return selectionOp;
}
spirv::LoopOp ControlFlowStructurizer::createLoopOp() {
// Create a builder and set the insertion point to the beginning of the
// merge block so that the newly created LoopOp will be inserted there.
OpBuilder builder(&mergeBlock->front());
@ -1564,7 +1612,7 @@ spirv::LoopOp LoopStructurizer::createLoopOp() {
return loopOp;
}
void LoopStructurizer::collectBlocksInConstruct() {
void ControlFlowStructurizer::collectBlocksInConstruct() {
assert(constructBlocks.empty() && "expected empty constructBlocks");
// Put the header block in the work list first.
@ -1573,35 +1621,45 @@ void LoopStructurizer::collectBlocksInConstruct() {
// For each item in the work list, add its successors under conditions.
for (unsigned i = 0; i < constructBlocks.size(); ++i) {
for (auto *successor : constructBlocks[i]->getSuccessors())
if (successor != mergeBlock && successor != continueBlock &&
constructBlocks.count(successor) == 0) {
if (successor != mergeBlock && successor != continueBlock)
constructBlocks.insert(successor);
}
}
}
LogicalResult LoopStructurizer::structurizeImpl() {
auto loopOp = createLoopOp();
if (!loopOp)
LogicalResult ControlFlowStructurizer::structurizeImpl() {
Operation *op = nullptr;
bool isLoop = continueBlock != nullptr;
if (isLoop) {
if (auto loopOp = createLoopOp())
op = loopOp.getOperation();
} else {
if (auto selectionOp = createSelectionOp())
op = selectionOp.getOperation();
}
if (!op)
return failure();
Region &body = op->getRegion(0);
BlockAndValueMapping mapper;
// All references to the old merge block should be directed to the loop
// merge block in the LoopOp's region.
mapper.map(mergeBlock, &loopOp.body().back());
mapper.map(mergeBlock, &body.back());
collectBlocksInConstruct();
// Add the loop continue block at the last so it's the second to last block
// in LoopOp's region.
constructBlocks.insert(continueBlock);
if (isLoop) {
// Add the loop continue block at the last so it's the second to last block
// in LoopOp's region.
constructBlocks.insert(continueBlock);
}
// We've identified all blocks belonging to the loop's region. Now need to
// "move" them into the loop. Instead of really moving the blocks, in the
// following we copy them and remap all values and branches. This is because:
// We've identified all blocks belonging to the selection/loop's region. Now
// need to "move" them into the selection/loop. Instead of really moving the
// blocks, in the following we copy them and remap all values and branches.
// This is because:
// * Inserting a block into a region requires the block not in any region
// before. But loops can nest so we can create loop ops in a nested manner,
// which means some blocks may already be in a loop region when to be moved
// again.
// before. But selections/loops can nest so we can create selection/loop ops
// in a nested manner, which means some blocks may already be in a
// selection/loop region when to be moved again.
// * It's much trickier to fix up the branches into and out of the loop's
// region: we need to treat not-moved blocks and moved blocks differently:
// Not-moved blocks jumping to the loop header block need to jump to the
@ -1611,16 +1669,16 @@ LogicalResult LoopStructurizer::structurizeImpl() {
// We cannot use replaceAllUsesWith clearly and it's harder to follow the
// logic.
// Create a corresponding block in the LoopOp's region for each block in
// this loop construct.
OpBuilder loopBuilder(loopOp.body());
// Create a corresponding block in the SelectionOp/LoopOp's region for each
// block in this loop construct.
OpBuilder builder(body);
for (auto *block : constructBlocks) {
assert(block->getNumArguments() == 0 &&
"block in loop construct should not have arguments");
// Create an block and insert it before the loop merge block in the
// LoopOp's region.
auto *newBlock = loopBuilder.createBlock(&loopOp.body().back());
auto *newBlock = builder.createBlock(&body.back());
mapper.map(block, newBlock);
for (auto &op : *block)
@ -1636,30 +1694,30 @@ LogicalResult LoopStructurizer::structurizeImpl() {
if (auto *mappedOp = mapper.lookupOrNull(succOp.get()))
succOp.set(mappedOp);
};
for (auto &block : loopOp.body()) {
for (auto &block : body) {
block.walk(remapOperands);
}
// We have created the LoopOp and "moved" all blocks belonging to the loop
// construct into its region. Next we need to fix the connections between
// this new LoopOp with existing blocks.
// We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
// the selection/loop construct into its region. Next we need to fix the
// connections between this new SelectionOp/LoopOp with existing blocks.
// All existing incoming branches should go to the merge block, where the
// LoopOp resides right now.
// SelectionOp/LoopOp resides right now.
headerBlock->replaceAllUsesWith(mergeBlock);
// The loop entry block should have a unconditional branch jumping to the
// loop header block.
loopBuilder.setInsertionPointToEnd(&loopOp.body().front());
loopBuilder.create<spirv::BranchOp>(location,
mapper.lookupOrNull(headerBlock));
// All the blocks cloned into the LoopOp's region can now be deleted.
for (auto *block : constructBlocks) {
block->clear();
block->erase();
if (isLoop) {
// The loop entry block should have a unconditional branch jumping to the
// loop header block.
builder.setInsertionPointToEnd(&body.front());
builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock));
}
// All the blocks cloned into the SelectionOp/LoopOp's region can now be
// deleted.
for (auto *block : constructBlocks)
block->erase();
return success();
}
@ -1668,23 +1726,21 @@ LogicalResult Deserializer::structurizeControlFlow() {
while (!blockMergeInfo.empty()) {
auto *headerBlock = blockMergeInfo.begin()->first;
const auto &mergeInfo = blockMergeInfo.begin()->second;
LLVM_DEBUG(llvm::dbgs() << "[cf] header block @ " << headerBlock << "\n");
const auto &mergeInfo = blockMergeInfo.begin()->second;
auto *mergeBlock = mergeInfo.mergeBlock;
auto *continueBlock = mergeInfo.continueBlock;
LLVM_DEBUG(llvm::dbgs() << "[cf] header block @ " << headerBlock << "\n");
assert(mergeBlock && "merge block cannot be nullptr");
LLVM_DEBUG(llvm::dbgs() << "[cf] merge block @ " << mergeBlock << "\n");
if (!continueBlock) {
return emitError(unknownLoc, "structurizing selection unimplemented");
if (continueBlock) {
LLVM_DEBUG(llvm::dbgs()
<< "[cf] continue block @ " << continueBlock << "\n");
}
LLVM_DEBUG(llvm::dbgs()
<< "[cf] continue block @ " << continueBlock << "\n");
if (failed(LoopStructurizer::structurize(unknownLoc, headerBlock,
mergeBlock, continueBlock))) {
if (failed(ControlFlowStructurizer::structurize(unknownLoc, headerBlock,
mergeBlock, continueBlock)))
return failure();
}
blockMergeInfo.erase(headerBlock);
}
@ -1830,6 +1886,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
return processBranch(operands);
case spirv::Opcode::OpBranchConditional:
return processBranchConditional(operands);
case spirv::Opcode::OpSelectionMerge:
return processSelectionMerge(operands);
case spirv::Opcode::OpLoopMerge:
return processLoopMerge(operands);
default:

View File

@ -250,6 +250,8 @@ private:
processBlock(Block *block,
llvm::function_ref<void()> actionBeforeTerminator = nullptr);
LogicalResult processSelectionOp(spirv::SelectionOp selectionOp);
LogicalResult processLoopOp(spirv::LoopOp loopOp);
LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);
@ -1220,10 +1222,18 @@ Serializer::processBlock(Block *block,
namespace {
/// A pre-order depth-first vistor for processing basic blocks in a spv.loop op.
///
/// This visitor is special tailored for spv.loop block serialization to satisfy
/// SPIR-V validation rules. It should not be used as a general depth-first
/// block visitor.
class LoopBlockVisitor {
/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
/// of blocks in a function must satisfy the rule that blocks appear before all
/// blocks they dominate." This can be achieved by a pre-order CFG traversal
/// algorithm. To make the serialization output more logical and readable to
/// human, we perform depth-first CFG traversal and delay the serialization of
/// the merge block (and the continue block) until after all other blocks have
/// been processed.
///
/// This visitor is special tailored for spv.selection or spv.loop block
/// serialization to satisfy SPIR-V validation rules. It should not be used
/// as a general depth-first block visitor.
class ControlFlowBlockVisitor {
public:
using BlockHandlerType = llvm::function_ref<LogicalResult(Block *)>;
@ -1232,12 +1242,13 @@ public:
/// Skips handling the `headerBlock` and blocks in the `skipBlocks` list.
static LogicalResult visit(Block *headerBlock, BlockHandlerType blockHandler,
ArrayRef<Block *> skipBlocks) {
return LoopBlockVisitor(blockHandler, skipBlocks)
return ControlFlowBlockVisitor(blockHandler, skipBlocks)
.visitHeaderBlock(headerBlock);
}
private:
LoopBlockVisitor(BlockHandlerType blockHandler, ArrayRef<Block *> skipBlocks)
ControlFlowBlockVisitor(BlockHandlerType blockHandler,
ArrayRef<Block *> skipBlocks)
: blockHandler(blockHandler),
doneBlocks(skipBlocks.begin(), skipBlocks.end()) {}
@ -1274,16 +1285,54 @@ private:
};
} // namespace
LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
// of blocks in a function must satisfy the rule that blocks appear before all
// blocks they dominate." This can be achieved by a pre-order CFG traversal
// algorithm. To make the serialization output more logical and readable to
// human, we perform depth-first CFG traversal and delay the serialization of
// the continue block and the merge block until after all other blocks have
// been processed.
LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
// Assign <id>s to all blocks so that branches inside the SelectionOp can
// resolve properly.
auto &body = selectionOp.body();
for (Block &block : body)
assignBlockID(&block);
// Assign <id>s to all blocks so that branchs inside the LoopOp can resolve
auto *headerBlock = selectionOp.getHeaderBlock();
auto *mergeBlock = selectionOp.getMergeBlock();
auto headerID = findBlockID(headerBlock);
auto mergeID = findBlockID(mergeBlock);
// This selection is in some MLIR block with preceding and following ops. In
// the binary format, it should reside in separate SPIR-V blocks from its
// preceding and following ops. So we need to emit unconditional branches to
// jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
// afterwards.
encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID});
// Emit the selection header block, which dominates all other blocks, first.
// We need to emit an OpSelectionMerge instruction before the loop header
// block's terminator.
auto emitSelectionMerge = [&]() {
// TODO(antiagainst): properly support loop control here
encodeInstructionInto(
functions, spirv::Opcode::OpSelectionMerge,
{mergeID, static_cast<uint32_t>(spirv::LoopControl::None)});
};
if (failed(processBlock(headerBlock, emitSelectionMerge)))
return failure();
// Process all blocks with a depth-first visitor starting from the header
// block. The selection header block and merge block are skipped by this
// visitor.
auto handleBlock = [&](Block *block) { return processBlock(block); };
if (failed(ControlFlowBlockVisitor::visit(headerBlock, handleBlock,
{mergeBlock})))
return failure();
// There is nothing to do for the merge block in the selection, which just
// contains a spv._merge op, itself. But we need to have an OpLabel
// instruction to start a new SPIR-V block for ops following this SelectionOp.
// The block should use the <id> for the merge block.
return encodeInstructionInto(functions, spirv::Opcode::OpLabel, {mergeID});
}
LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// Assign <id>s to all blocks so that branches inside the LoopOp can resolve
// properly. We don't need to assign for the entry block, which is just for
// satisfying MLIR region's structural requirement.
auto &body = loopOp.body();
@ -1303,7 +1352,6 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// preceding and following ops. So we need to emit unconditional branches to
// jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
// afterwards.
encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID});
// Emit the loop header block, which dominates all other blocks, first. We
@ -1322,8 +1370,8 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// block. The loop header block, loop continue block, and loop merge block are
// skipped by this visitor and handled later in this function.
auto handleBlock = [&](Block *block) { return processBlock(block); };
if (failed(LoopBlockVisitor::visit(headerBlock, handleBlock,
{continueBlock, mergeBlock})))
if (failed(ControlFlowBlockVisitor::visit(headerBlock, handleBlock,
{continueBlock, mergeBlock})))
return failure();
// We have handled all other blocks. Now get to the loop continue block.
@ -1332,7 +1380,8 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// There is nothing to do for the merge block in the loop, which just contains
// a spv._merge op, itself. But we need to have an OpLabel instruction to
// start a new SPIR-V block for ops following this LoopOp.
// start a new SPIR-V block for ops following this LoopOp. The block should
// use the <id> for the merge block.
return encodeInstructionInto(functions, spirv::Opcode::OpLabel, {mergeID});
}
@ -1438,6 +1487,9 @@ LogicalResult Serializer::processOperation(Operation *op) {
if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
return processGlobalVariableOp(varOp);
}
if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) {
return processSelectionOp(selectionOp);
}
if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) {
return processLoopOp(loopOp);
}

View File

@ -0,0 +1,49 @@
// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
spv.module "Logical" "GLSL450" {
func @selection(%cond: i1) -> () {
%zero = spv.constant 0: i32
%one = spv.constant 1: i32
%two = spv.constant 2: i32
%var = spv.Variable init(%zero) : !spv.ptr<i32, Function>
// CHECK: spv.Branch ^bb1
// CHECK-NEXT: ^bb1:
// CHECK-NEXT: spv.selection
spv.selection {
// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
spv.BranchConditional %cond, ^then, ^else
// CHECK-NEXT: ^bb1:
^then:
// CHECK-NEXT: spv.constant 1
// CHECK-NEXT: spv.Store
spv.Store "Function" %var, %one : i32
// CHECK-NEXT: spv.Branch ^bb3
spv.Branch ^merge
// CHECK-NEXT: ^bb2:
^else:
// CHECK-NEXT: spv.constant 2
// CHECK-NEXT: spv.Store
spv.Store "Function" %var, %two : i32
// CHECK-NEXT: spv.Branch ^bb3
spv.Branch ^merge
// CHECK-NEXT: ^bb3:
^merge:
// CHECK-NEXT: spv._merge
spv._merge
}
spv.Return
}
func @main() -> () {
spv.Return
}
spv.EntryPoint "GLCompute" @main
spv.ExecutionMode @main "LocalSize", 1, 1, 1
} attributes {
capabilities = ["Shader"]
}

View File

@ -404,16 +404,38 @@ func @only_entry_and_continue_branch_to_header() -> () {
// -----
//===----------------------------------------------------------------------===//
// spv.merge
// spv._merge
//===----------------------------------------------------------------------===//
func @merge() -> () {
// expected-error @+1 {{expects parent op 'spv.loop'}}
// expected-error @+1 {{expected parent op to be 'spv.selection' or 'spv.loop'}}
spv._merge
}
// -----
func @only_allowed_in_last_block(%cond : i1) -> () {
%zero = spv.constant 0: i32
%one = spv.constant 1: i32
%var = spv.Variable init(%zero) : !spv.ptr<i32, Function>
spv.selection {
spv.BranchConditional %cond, ^then, ^merge
^then:
spv.Store "Function" %var, %one : i32
// expected-error @+1 {{can only be used in the last block of 'spv.selection' or 'spv.loop'}}
spv._merge
^merge:
spv._merge
}
spv.Return
}
// -----
func @only_allowed_in_last_block() -> () {
%true = spv.constant true
spv.loop {
@ -421,7 +443,7 @@ func @only_allowed_in_last_block() -> () {
^header:
spv.BranchConditional %true, ^body, ^merge
^body:
// expected-error @+1 {{can only be used in the last block of 'spv.loop'}}
// expected-error @+1 {{can only be used in the last block of 'spv.selection' or 'spv.loop'}}
spv._merge
^continue:
spv.Branch ^header
@ -487,3 +509,98 @@ func @value_type_mismatch() -> (f32) {
// expected-error @+1 {{return value's type ('i32') mismatch with function's result type ('f32')}}
spv.ReturnValue %0 : i32
}
// -----
//===----------------------------------------------------------------------===//
// spv.selection
//===----------------------------------------------------------------------===//
func @selection(%cond: i1) -> () {
%zero = spv.constant 0: i32
%one = spv.constant 1: i32
%var = spv.Variable init(%zero) : !spv.ptr<i32, Function>
// CHECK: spv.selection {
spv.selection {
// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
spv.BranchConditional %cond, ^then, ^merge
// CHECK: ^bb1
^then:
spv.Store "Function" %var, %one : i32
// CHECK: spv.Branch ^bb2
spv.Branch ^merge
// CHECK: ^bb2
^merge:
// CHECK-NEXT: spv._merge
spv._merge
}
spv.Return
}
// -----
func @selection(%cond: i1) -> () {
%zero = spv.constant 0: i32
%one = spv.constant 1: i32
%two = spv.constant 2: i32
%var = spv.Variable init(%zero) : !spv.ptr<i32, Function>
// CHECK: spv.selection {
spv.selection {
// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
spv.BranchConditional %cond, ^then, ^else
// CHECK: ^bb1
^then:
spv.Store "Function" %var, %one : i32
// CHECK: spv.Branch ^bb3
spv.Branch ^merge
// CHECK: ^bb2
^else:
spv.Store "Function" %var, %two : i32
// CHECK: spv.Branch ^bb3
spv.Branch ^merge
// CHECK: ^bb3
^merge:
// CHECK-NEXT: spv._merge
spv._merge
}
spv.Return
}
// -----
// CHECK-LABEL: @empty_region
func @empty_region() -> () {
// CHECK: spv.selection
spv.selection {
}
return
}
// -----
func @wrong_merge_block() -> () {
// expected-error @+1 {{last block must be the merge block with only one 'spv._merge' op}}
spv.selection {
spv.Return
}
return
}
// -----
func @missing_entry_block() -> () {
// expected-error @+1 {{must have a selection header block}}
spv.selection {
spv._merge
}
return
}