[spirv] Support OpPhi using block arguments

This CL adds another control flow instruction in SPIR-V: OpPhi.
It is modelled as block arguments to be idiomatic with MLIR.
See the rationale.md doc for "Block Arguments vs PHI nodes".
Serialization and deserialization is updated to convert between
block arguments and SPIR-V OpPhi instructions.

PiperOrigin-RevId: 277161545
This commit is contained in:
Lei Zhang 2019-10-28 15:58:11 -07:00 committed by A. Unique TensorFlower
parent 66ec24d833
commit ca2538e9a7
5 changed files with 658 additions and 117 deletions

View File

@ -199,6 +199,12 @@ A SPIR-V function is defined using the builtin `func` op. `spv.module` verifies
that the functions inside it comply with SPIR-V requirements: at most one
result, no nested functions, and so on.
## Operations
Operation documentation is written in each op's Op Definition Spec using
TableGen. A markdown version of the doc can be generated using `mlir-tblgen
-gen-doc`.
## Control Flow
SPIR-V binary format uses merge instructions (`OpSelectionMerge` and
@ -385,7 +391,63 @@ func @loop(%count : i32) -> () {
}
```
## Serialization
### Block argument for Phi
There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi`
instructions are modelled as block arguments in the SPIR-V dialect. (See the
[Rationale][Rationale] doc for "Block Arguments vs Phi nodes".) Each block
argument corresponds to one `OpPhi` instruction in the SPIR-V binary format. For
example, for the following SPIR-V function `foo`:
```spirv
%foo = OpFunction %void None ...
%entry = OpLabel
%var = OpVariable %_ptr_Function_int Function
OpSelectionMerge %merge None
OpBranchConditional %true %true %false
%true = OpLabel
OpBranch %phi
%false = OpLabel
OpBranch %phi
%phi = OpLabel
%val = OpPhi %int %int_1 %false %int_0 %true
OpStore %var %val
OpReturn
%merge = OpLabel
OpReturn
OpFunctionEnd
```
It will be represented as:
```mlir
func @foo() -> () {
%var = spv.Variable : !spv.ptr<i32, Function>
spv.selection {
%true = spv.constant true
spv.BranchConditional %true, ^true, ^false
^true:
%zero = spv.constant 0 : i32
spv.Branch ^phi(%zero: i32)
^false:
%one = spv.constant 1 : i32
spv.Branch ^phi(%one: i32)
^phi(%arg: i32):
spv.Store "Function" %var, %arg : i32
spv.Return
^merge:
spv._merge
}
spv.Return
}
```
## Serialization and deserialization
The serialization library provides two entry points, `mlir::spirv::serialize()`
and `mlir::spirv::deserialize()`, for converting a MLIR SPIR-V module to binary
@ -399,6 +461,25 @@ the SPIR-V binary module and does not guarantee roundtrip equivalence (at least
for now). For the latter, please use the assembler/disassembler in the
[SPIRV-Tools][SPIRV-Tools] project.
A few transformations are performed in the process of serialization because of
the representational differences between SPIR-V dialect and binary format:
* Attributes on `spv.module` are emitted as their corresponding SPIR-V
instructions.
* `spv.constant`s are unified and placed in the SPIR-V binary module section
for types, constants, and global variables.
* `spv.selection`s and `spv.loop`s are emitted as basic blocks with `Op*Merge`
instructions in the header block as required by the binary format.
Similarly, a few transformations are performed during deserialization:
* Instructions for execution environment requirements will be placed as
attribues on `spv.module`.
* `OpConstant*` instructions are materialized as `spv.constant` at each use
site.
* `OpPhi` instructions are converted to block arguments.
* Structured control flow are placed inside `spv.selection` and `spv.loop`.
[SPIR-V]: https://www.khronos.org/registry/spir-v/
[ArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeArray
[ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage
@ -406,3 +487,4 @@ for now). For the latter, please use the assembler/disassembler in the
[RuntimeArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeRuntimeArray
[StructType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Structure
[SPIRV-Tools]: https://github.com/KhronosGroup/SPIRV-Tools
[Rationale]: https://github.com/tensorflow/mlir/blob/master/g3doc/Rationale.md#block-arguments-vs-phi-nodes

View File

@ -166,6 +166,7 @@ def SPV_OC_OpBitwiseXor : I32EnumAttrCase<"OpBitwiseXor", 198>;
def SPV_OC_OpBitwiseAnd : I32EnumAttrCase<"OpBitwiseAnd", 199>;
def SPV_OC_OpControlBarrier : I32EnumAttrCase<"OpControlBarrier", 224>;
def SPV_OC_OpMemoryBarrier : I32EnumAttrCase<"OpMemoryBarrier", 225>;
def SPV_OC_OpPhi : I32EnumAttrCase<"OpPhi", 245>;
def SPV_OC_OpLoopMerge : I32EnumAttrCase<"OpLoopMerge", 246>;
def SPV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerge", 247>;
def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>;
@ -205,8 +206,8 @@ def SPV_OpcodeAttr :
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd,
SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge,
SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpPhi,
SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
SPV_OC_OpModuleProcessed
]> {

View File

@ -41,8 +41,8 @@ using namespace mlir;
#define DEBUG_TYPE "spirv-deserialization"
// Decodes a string literal in `words` starting at `wordIndex`. Update the
// latter to point to the position in words after the string literal.
/// Decodes a string literal in `words` starting at `wordIndex`. Update the
/// latter to point to the position in words after the string literal.
static inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
unsigned &wordIndex) {
StringRef str(reinterpret_cast<const char *>(words.data() + wordIndex));
@ -50,11 +50,16 @@ static inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
return str;
}
// Extracts the opcode from the given first word of a SPIR-V instruction.
/// Extracts the opcode from the given first word of a SPIR-V instruction.
static inline spirv::Opcode extractOpcode(uint32_t word) {
return static_cast<spirv::Opcode>(word & 0xffff);
}
/// Returns true if the given `block` is a function entry block.
static inline bool isFnEntryBlock(Block *block) {
return block->isEntryBlock() && isa_and_nonnull<FuncOp>(block->getParentOp());
}
namespace {
/// A SPIR-V module serializer.
///
@ -130,6 +135,9 @@ private:
/// them to their handler method accordingly.
LogicalResult processFunction(ArrayRef<uint32_t> operands);
/// Processes OpFunctionEnd and finalizes function. This wires up block
/// argument created from OpPhi instructions and also structurizes control
/// flow.
LogicalResult processFunctionEnd(ArrayRef<uint32_t> operands);
/// Gets the constant's attribute and type associated with the given <id>.
@ -220,6 +228,9 @@ private:
// Control flow
//===--------------------------------------------------------------------===//
/// Returns the block for the given label <id>.
Block *getBlock(uint32_t id) const { return blockMap.lookup(id); }
// In SPIR-V, structured control flow is explicitly declared using merge
// instructions (OpSelectionMerge and OpLoopMerge). In the SPIR-V dialect,
// we use spv.selection and spv.loop to group structured control flow.
@ -242,9 +253,6 @@ private:
// block and redirect all branches to the old header block to the old
// merge block (which contains the spv.selection/spv.loop op now).
/// Returns the block for the given label <id>.
Block *getBlock(uint32_t id) const { return blockMap.lookup(id); }
/// A struct for containing a header block's merge and continue targets.
struct BlockMergeInfo {
Block *mergeBlock;
@ -255,11 +263,24 @@ private:
: mergeBlock(m), continueBlock(c) {}
};
/// Returns the merge and continue target info for the given `block` if it is
/// a header block.
BlockMergeInfo getBlockMergeInfo(Block *block) const {
return blockMergeInfo.lookup(block);
}
/// For OpPhi instructions, we use block arguments to represent them. OpPhi
/// encodes a list of (value, predecessor) pairs. At the time of handling the
/// block containing an OpPhi instruction, the predecessor block might not be
/// processed yet, also the value sent by it. So we need to defer handling
/// the block argument from the predecessors. We use the following approach:
///
/// 1. For each OpPhi instruction, add a block argument to the current block
/// in construction. Record the block argment in `valueMap` so its uses
/// can be resolved. For the list of (value, predecessor) pairs, update
/// `blockPhiInfo` for bookkeeping.
/// 2. After processing all blocks, loop over `blockPhiInfo` to fix up each
/// block recorded there to create the proper block arguments on their
/// terminators.
/// A data structure for containing a SPIR-V block's phi info. It will be
/// represented as block argument in SPIR-V dialect.
using BlockPhiInfo =
SmallVector<uint32_t, 2>; // The result <id> of the values sent
/// Gets or creates the block corresponding to the given label <id>. The newly
/// created block will always be placed at the end of the current function.
@ -278,6 +299,13 @@ private:
/// Processes a SPIR-V OpLoopMerge instruction with the given `operands`.
LogicalResult processLoopMerge(ArrayRef<uint32_t> operands);
/// Processes a SPIR-V OpPhi instruction with the given `operands`.
LogicalResult processPhi(ArrayRef<uint32_t> operands);
/// Creates block arguments on predecessors previously recorded when handling
/// OpPhi instructions.
LogicalResult wireUpBlockArgument();
/// Extracts blocks belonging to a structured selection/loop into a
/// spv.selection/spv.loop op. This method iterates until all blocks
/// declared as selection/loop headers are handled.
@ -407,6 +435,9 @@ private:
// Header block to its merge (and continue) target mapping.
DenseMap<Block *, BlockMergeInfo> blockMergeInfo;
// Block to its phi (block argument) mapping.
DenseMap<Block *, BlockPhiInfo> blockPhiInfo;
// Result <id> to value mapping.
DenseMap<uint32_t, Value *> valueMap;
@ -453,7 +484,8 @@ Deserializer::Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context)
module(createModuleOp()), opBuilder(module->body()) {}
LogicalResult Deserializer::deserialize() {
LLVM_DEBUG(llvm::dbgs() << "++ deserialization started\n");
LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n");
if (failed(processHeader()))
return failure();
@ -483,7 +515,7 @@ LogicalResult Deserializer::deserialize() {
attachCapabilities();
attachExtensions();
LLVM_DEBUG(llvm::dbgs() << "++ deserialization succeeded\n");
LLVM_DEBUG(llvm::dbgs() << "+++ completed deserialization +++\n");
return success();
}
@ -748,10 +780,10 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
auto funcOp = opBuilder.create<FuncOp>(unknownLoc, fnName, functionType,
ArrayRef<NamedAttribute>());
curFunction = funcMap[operands[1]] = funcOp;
LLVM_DEBUG(llvm::dbgs() << "[fn] processing function " << fnName << " (type="
<< fnType << ", id=" << operands[1] << ")\n");
LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = "
<< fnType << ", id = " << operands[1] << ") --\n");
auto *entryBlock = funcOp.addEntryBlock();
LLVM_DEBUG(llvm::dbgs() << "[block] created entry block @ " << entryBlock
LLVM_DEBUG(llvm::dbgs() << "[block] created entry block " << entryBlock
<< "\n");
// Parse the op argument instructions
@ -810,8 +842,8 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
}
if (opcode == spirv::Opcode::OpFunctionEnd) {
LLVM_DEBUG(llvm::dbgs()
<< "[fn] completed function '" << fnName << "' (type=" << fnType
<< ", id=" << operands[1] << ")\n");
<< "-- completed function '" << fnName << "' (type = " << fnType
<< ", id = " << operands[1] << ") --\n");
return processFunctionEnd(instOperands);
}
if (opcode != spirv::Opcode::OpLabel) {
@ -838,8 +870,8 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
return failure();
}
LLVM_DEBUG(llvm::dbgs() << "[fn] completed function '" << fnName << "' (type="
<< fnType << ", id=" << operands[1] << ")\n");
LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = "
<< fnType << ", id = " << operands[1] << ") --\n");
return processFunctionEnd(instOperands);
}
@ -849,8 +881,9 @@ LogicalResult Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
}
// Wire up block arguments from OpPhi instructions.
// Put all structured control flow in spv.selection/spv.loop ops.
if (failed(structurizeControlFlow())) {
if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
return failure();
}
@ -1438,7 +1471,7 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
Block *Deserializer::getOrCreateBlock(uint32_t id) {
if (auto *block = getBlock(id)) {
LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id=" << id
LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id
<< " @ " << block << "\n");
return block;
}
@ -1447,7 +1480,7 @@ Block *Deserializer::getOrCreateBlock(uint32_t id) {
// or spv.loop or function). Create it into the function for now and sort
// out the proper place later.
auto *block = curFunction->addBlock();
LLVM_DEBUG(llvm::dbgs() << "[block] created block for id=" << id << " @ "
LLVM_DEBUG(llvm::dbgs() << "[block] created block for id = " << id << " @ "
<< block << "\n");
return blockMap[id] = block;
}
@ -1509,7 +1542,7 @@ LogicalResult Deserializer::processLabel(ArrayRef<uint32_t> operands) {
auto labelID = operands[0];
// We may have forward declared this block.
auto *block = getOrCreateBlock(labelID);
LLVM_DEBUG(llvm::dbgs() << "[block] populating block @ " << block << "\n");
LLVM_DEBUG(llvm::dbgs() << "[block] populating block " << block << "\n");
// If we have seen this block, make sure it was just a forward declaration.
assert(block->empty() && "re-deserialize the same block!");
@ -1574,6 +1607,37 @@ LogicalResult Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
return success();
}
LogicalResult Deserializer::processPhi(ArrayRef<uint32_t> operands) {
if (!curBlock) {
return emitError(unknownLoc, "OpPhi must appear in a block");
}
if (operands.size() < 4) {
return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
"and variable-parent pairs");
}
// Create a block argument for this OpPhi instruction.
Type blockArgType = getType(operands[0]);
BlockArgument *blockArg = curBlock->addArgument(blockArgType);
valueMap[operands[1]] = blockArg;
LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg
<< " id = " << operands[1] << " of type "
<< blockArgType << '\n');
// For each (value, predecessor) pair, insert the value to the predecessor's
// blockPhiInfo entry so later we can fix the block argument there.
for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
uint32_t value = operands[i];
Block *predecessor = getOrCreateBlock(operands[i + 1]);
blockPhiInfo[predecessor].push_back(value);
LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor
<< " with arg id = " << value << '\n');
}
return success();
}
namespace {
/// A class for putting all blocks in a structured selection/loop in a
/// spv.selection/spv.loop op.
@ -1711,6 +1775,14 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
mapper.map(block, newBlock);
LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock
<< " from block " << block << "\n");
if (!isFnEntryBlock(block)) {
for (BlockArgument *blockArg : block->getArguments()) {
auto *newArg = newBlock->addArgument(blockArg->getType());
mapper.map(blockArg, newArg);
LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg
<< " to " << newArg);
}
}
for (auto &op : *block)
newBlock->push_back(op.clone(mapper));
@ -1758,7 +1830,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
// If the function's header block is also part of the structured control
// flow, we cannot just simply erase it because it may contain arguments
// matching the function signature and used by the cloned blocks.
if (block->isEntryBlock() && isa<FuncOp>(block->getParentOp())) {
if (isFnEntryBlock(block)) {
LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block
<< " to only contain a spv.Branch op\n");
// Still keep the function entry block for the potential block arguments,
@ -1775,29 +1847,77 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
return success();
}
LogicalResult Deserializer::wireUpBlockArgument() {
LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n");
OpBuilder::InsertionGuard guard(opBuilder);
for (const auto &info : blockPhiInfo) {
Block *block = info.first;
const BlockPhiInfo &phiInfo = info.second;
LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n");
LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n");
LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n');
// Set insertion point to before this block's terminator early because we
// may materialize ops via getValue() call.
auto *op = block->getTerminator();
opBuilder.setInsertionPoint(op);
SmallVector<Value *, 4> blockArgs;
blockArgs.reserve(phiInfo.size());
for (uint32_t valueId : phiInfo) {
if (Value *value = getValue(valueId)) {
blockArgs.push_back(value);
LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value
<< " id = " << valueId << '\n');
} else {
return emitError(unknownLoc, "OpPhi references undefined value!");
}
}
if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
// Replace the previous branch op with a new one with block arguments.
opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
blockArgs);
branchOp.erase();
} else {
return emitError(unknownLoc, "unimplemented terminator for Phi creation");
}
LLVM_DEBUG(llvm::dbgs() << "[phi] after creating block argument:\n");
LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n');
}
blockPhiInfo.clear();
LLVM_DEBUG(llvm::dbgs() << "[phi] completed wiring up block arguments\n");
return success();
}
LogicalResult Deserializer::structurizeControlFlow() {
LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n");
while (!blockMergeInfo.empty()) {
auto *headerBlock = blockMergeInfo.begin()->first;
LLVM_DEBUG(llvm::dbgs() << "[cf] header block @ " << headerBlock << "\n");
for (const auto &info : blockMergeInfo) {
auto *headerBlock = info.first;
LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << "\n");
const auto &mergeInfo = blockMergeInfo.begin()->second;
const auto &mergeInfo = info.second;
auto *mergeBlock = mergeInfo.mergeBlock;
auto *continueBlock = mergeInfo.continueBlock;
assert(mergeBlock && "merge block cannot be nullptr");
LLVM_DEBUG(llvm::dbgs() << "[cf] merge block @ " << mergeBlock << "\n");
LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << "\n");
if (continueBlock) {
LLVM_DEBUG(llvm::dbgs()
<< "[cf] continue block @ " << continueBlock << "\n");
<< "[cf] continue block " << continueBlock << "\n");
}
if (failed(ControlFlowStructurizer::structurize(unknownLoc, headerBlock,
mergeBlock, continueBlock)))
return failure();
blockMergeInfo.erase(headerBlock);
}
blockMergeInfo.clear();
LLVM_DEBUG(llvm::dbgs() << "[cf] completed structurizing control flow\n");
return success();
@ -1949,6 +2069,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
return processSelectionMerge(operands);
case spirv::Opcode::OpLoopMerge:
return processLoopMerge(operands);
case spirv::Opcode::OpPhi:
return processPhi(operands);
case spirv::Opcode::OpUndef:
return processUndef(operands);
default:

View File

@ -32,8 +32,11 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "spirv-serialization"
using namespace mlir;
/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
@ -49,6 +52,77 @@ LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
return success();
}
namespace {
/// A pre-order depth-first vistor for processing basic blocks.
///
/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
/// of blocks in a function must satisfy the rule that blocks appear before all
/// blocks they dominate." This can be achieved by a pre-order CFG traversal
/// algorithm. To make the serialization output more logical and readable to
/// human, we perform depth-first CFG traversal and delay the serialization of
/// the merge block and the continue block, if exists, until after all other
/// blocks have been processed.
///
/// This visitor is special tailored for SPIR-V functions, spv.selection or
/// spv.loop block serialization to satisfy SPIR-V validation rules. It should
/// not be used as a general depth-first block visitor.
class PrettyBlockOrderVisitor {
public:
using BlockHandlerType = llvm::function_ref<LogicalResult(Block *)>;
/// Visits the basic blocks starting from the given `headerBlock`'s successors
/// in pre-order depth-first manner and calls `blockHandler` on each block.
/// Skips handling blocks in the `skipBlocks` list. If `headerBlock` is also
/// in `skipBlocks` list, still handles all its successors.
static LogicalResult visit(Block *headerBlock, BlockHandlerType blockHandler,
ArrayRef<Block *> skipBlocks = {}) {
return PrettyBlockOrderVisitor(blockHandler, skipBlocks)
.visitHeaderBlock(headerBlock);
}
private:
PrettyBlockOrderVisitor(BlockHandlerType blockHandler,
ArrayRef<Block *> skipBlocks)
: blockHandler(blockHandler),
doneBlocks(skipBlocks.begin(), skipBlocks.end()) {}
LogicalResult visitHeaderBlock(Block *header) {
// Skip processing the header block if requested.
if (!llvm::is_contained(doneBlocks, header)) {
if (failed(blockHandler(header)))
return failure();
doneBlocks.insert(header);
}
for (auto *successor : header->getSuccessors()) {
if (failed(visitNormalBlock(successor)))
return failure();
}
return success();
}
LogicalResult visitNormalBlock(Block *block) {
if (doneBlocks.count(block))
return success();
if (failed(blockHandler(block)))
return failure();
doneBlocks.insert(block);
for (auto *successor : block->getSuccessors()) {
if (failed(visitNormalBlock(successor)))
return failure();
}
return success();
}
BlockHandlerType blockHandler;
SmallPtrSet<Block *, 4> doneBlocks;
};
} // namespace
namespace {
/// A SPIR-V module serializer.
@ -75,6 +149,9 @@ public:
/// Collects the final SPIR-V `binary`.
void collect(SmallVectorImpl<uint32_t> &binary);
/// (For debugging) prints each value and its corresponding result <id>.
void printValueIDMap(raw_ostream &os);
private:
// Note that there are two main categories of methods in this class:
// * process*() methods are meant to fully serialize a SPIR-V module entity
@ -244,19 +321,25 @@ private:
// Control flow
//===--------------------------------------------------------------------===//
/// Returns the result <id> for the given block.
uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); }
uint32_t assignBlockID(Block *block);
/// Returns the result <id> for the given block. If no <id> has been assigned,
/// assigns the next available <id>
uint32_t getOrCreateBlockID(Block *block);
// Processes the given `block` and emits SPIR-V instructions for all ops
// inside. Does not emit OpLabel for this block if `omitLabel` is true.
// `actionBeforeTerminator` is a callback that will be invoked before handling
// the terminator op. It can be used to inject the Op*Merge instruction if
// this is a SPIR-V selection/loop header block.
/// Processes the given `block` and emits SPIR-V instructions for all ops
/// inside. Does not emit OpLabel for this block if `omitLabel` is true.
/// `actionBeforeTerminator` is a callback that will be invoked before
/// handling the terminator op. It can be used to inject the Op*Merge
/// instruction if this is a SPIR-V selection/loop header block.
LogicalResult
processBlock(Block *block, bool omitLabel = false,
llvm::function_ref<void()> actionBeforeTerminator = nullptr);
/// Emits OpPhi instructions for the given block if it has block arguments.
LogicalResult emitPhiForBlockArguments(Block *block);
LogicalResult processSelectionOp(spirv::SelectionOp selectionOp);
LogicalResult processLoopOp(spirv::LoopOp loopOp);
@ -356,6 +439,46 @@ private:
/// Map from extended instruction set name to <id>s.
llvm::StringMap<uint32_t> extendedInstSetIDMap;
/// Map from values used in OpPhi instructions to their offset in the
/// `functions` section.
///
/// When processing a block with arguments, we need to emit OpPhi
/// instructions to record the predecessor block <id>s and the values they
/// send to the block in question. But it's not guaranteed all values are
/// visited and thus assigned result <id>s. So we need this list to capture
/// the offsets into `functions` where a value is used so that we can fix it
/// up later after processing all the blocks in a function.
///
/// More concretely, say if we are visiting the following blocks:
///
/// ```mlir
/// ^phi(%arg0: i32):
/// ...
/// ^parent1:
/// ...
/// spv.Branch ^phi(%val0: i32)
/// ^parent2:
/// ...
/// spv.Branch ^phi(%val1: i32)
/// ```
///
/// When we are serializing the `^phi` block, we need to emit at the beginning
/// of the block OpPhi instructions which has the following parameters:
///
/// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1
/// id-for-%val1 id-for-^parent2
///
/// But we don't know the <id> for %val0 and %val1 yet. One way is to visit
/// all the blocks twice and use the first visit to assign an <id> to each
/// value. But it's paying the overheads just for OpPhi emission. Instead,
/// we still visit the blocks once for emssion. When we emit the OpPhi
/// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1.
/// At the same time, we record their offsets in the emitted binary (which is
/// placed inside `functions`) here. And then after emitting all blocks, we
/// replace the dummy <id> 0 with the real result <id> by overwriting
/// `functions[offset]`.
DenseMap<Value *, llvm::SmallVector<size_t, 1>> deferredPhiValues;
};
} // namespace
@ -363,6 +486,8 @@ Serializer::Serializer(spirv::ModuleOp module)
: module(module), mlirBuilder(module.getContext()) {}
LogicalResult Serializer::serialize() {
LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
if (failed(module.verify()))
return failure();
@ -378,6 +503,8 @@ LogicalResult Serializer::serialize() {
return failure();
}
}
LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
return success();
}
@ -403,6 +530,24 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
binary.append(functions.begin(), functions.end());
}
void Serializer::printValueIDMap(raw_ostream &os) {
os << "\n= Value <id> Map =\n\n";
for (auto valueIDPair : valueIDMap) {
Value *val = valueIDPair.first;
os << " " << val << " "
<< "id = " << valueIDPair.second << ' ';
if (auto *op = val->getDefiningOp()) {
os << "from op '" << op->getName() << "'";
} else if (auto *arg = dyn_cast<BlockArgument>(val)) {
Block *block = arg->getOwner();
os << "from argument of block " << block << ' ';
os << " in op '" << block->getParentOp()->getName() << "'";
}
os << '\n';
}
}
//===----------------------------------------------------------------------===//
// Module structure
//===----------------------------------------------------------------------===//
@ -564,6 +709,8 @@ Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex,
} // namespace
LogicalResult Serializer::processFuncOp(FuncOp op) {
LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
uint32_t fnTypeID = 0;
// Generate type of the function.
processType(op.getLoc(), op.getType(), fnTypeID);
@ -610,11 +757,24 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
return op.emitError("external function is unhandled");
}
for (auto &block : op) {
if (failed(processBlock(&block)))
return failure();
}
if (failed(PrettyBlockOrderVisitor::visit(
&op.front(), [&](Block *block) { return processBlock(block); })))
return failure();
// There might be OpPhi instructions who have value references needing to fix.
for (auto deferredValue : deferredPhiValues) {
Value *value = deferredValue.first;
uint32_t id = getValueID(value);
LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
<< " to id = " << id << '\n');
assert(id && "OpPhi references undefined value!");
for (size_t offset : deferredValue.second)
functions[offset] = id;
}
deferredPhiValues.clear();
LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
<< "' --\n");
// Insert OpFunctionEnd.
return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {});
}
@ -842,7 +1002,7 @@ Serializer::prepareFunctionType(Location loc, FunctionType type,
SmallVectorImpl<uint32_t> &operands) {
typeEnum = spirv::Opcode::OpTypeFunction;
assert(type.getNumResults() <= 1 &&
"Serialization supports only a single return value");
"serialization supports only a single return value");
uint32_t resultID = 0;
if (failed(processType(
loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
@ -1221,24 +1381,31 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
// Control flow
//===----------------------------------------------------------------------===//
uint32_t Serializer::assignBlockID(Block *block) {
assert(blockIDMap.lookup(block) == 0 && "block already has <id>");
uint32_t Serializer::getOrCreateBlockID(Block *block) {
if (uint32_t id = getBlockID(block))
return id;
return blockIDMap[block] = getNextID();
}
LogicalResult
Serializer::processBlock(Block *block, bool omitLabel,
llvm::function_ref<void()> actionBeforeTerminator) {
LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
LLVM_DEBUG(block->print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n');
if (!omitLabel) {
auto blockID = getBlockID(block);
if (blockID == 0) {
blockID = assignBlockID(block);
}
uint32_t blockID = getOrCreateBlockID(block);
LLVM_DEBUG(llvm::dbgs()
<< "[block] " << block << " (id = " << blockID << ")\n");
// Emit OpLabel for this block.
encodeInstructionInto(functions, spirv::Opcode::OpLabel, {blockID});
}
// Emit OpPhi instructions for block arguments, if any.
if (failed(emitPhiForBlockArguments(block)))
return failure();
// Process each op in this block except the terminator.
for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
if (failed(processOperation(&op)))
@ -1254,78 +1421,78 @@ Serializer::processBlock(Block *block, bool omitLabel,
return success();
}
namespace {
/// A pre-order depth-first vistor for processing basic blocks in a spv.loop op.
///
/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
/// of blocks in a function must satisfy the rule that blocks appear before all
/// blocks they dominate." This can be achieved by a pre-order CFG traversal
/// algorithm. To make the serialization output more logical and readable to
/// human, we perform depth-first CFG traversal and delay the serialization of
/// the merge block (and the continue block) until after all other blocks have
/// been processed.
///
/// This visitor is special tailored for spv.selection or spv.loop block
/// serialization to satisfy SPIR-V validation rules. It should not be used
/// as a general depth-first block visitor.
class ControlFlowBlockVisitor {
public:
using BlockHandlerType = llvm::function_ref<LogicalResult(Block *)>;
/// Visits the basic blocks starting from the given `headerBlock`'s successors
/// in pre-order depth-first manner and calls `blockHandler` on each block.
/// Skips handling the `headerBlock` and blocks in the `skipBlocks` list.
static LogicalResult visit(Block *headerBlock, BlockHandlerType blockHandler,
ArrayRef<Block *> skipBlocks) {
return ControlFlowBlockVisitor(blockHandler, skipBlocks)
.visitHeaderBlock(headerBlock);
}
private:
ControlFlowBlockVisitor(BlockHandlerType blockHandler,
ArrayRef<Block *> skipBlocks)
: blockHandler(blockHandler),
doneBlocks(skipBlocks.begin(), skipBlocks.end()) {}
LogicalResult visitHeaderBlock(Block *header) {
// Skip processing the header block.
doneBlocks.insert(header);
for (auto *successor : header->getSuccessors()) {
if (failed(visitNormalBlock(successor)))
return failure();
}
LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
// Nothing to do if this block has no arguments or it's the entry block, which
// always has the same arguments as the function signature.
if (block->args_empty() || block->isEntryBlock())
return success();
// If the block has arguments, we need to create SPIR-V OpPhi instructions.
// A SPIR-V OpPhi instruction is of the syntax:
// OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
// So we need to collect all predecessor blocks and the arguments they send
// to this block.
SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors;
for (Block *predecessor : block->getPredecessors()) {
auto *op = predecessor->getTerminator();
if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
predecessors.emplace_back(predecessor, branchOp.operand_begin());
} else {
return op->emitError("unimplemented terminator for Phi creation");
}
}
LogicalResult visitNormalBlock(Block *block) {
if (doneBlocks.count(block))
return success();
// Then create OpPhi instruction for each of the block argument.
for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
BlockArgument *arg = block->getArgument(argIndex);
if (failed(blockHandler(block)))
// Get the type <id> and result <id> for this OpPhi instruction.
uint32_t phiTypeID = 0;
if (failed(processType(arg->getLoc(), arg->getType(), phiTypeID)))
return failure();
doneBlocks.insert(block);
uint32_t phiID = getNextID();
for (auto *successor : block->getSuccessors()) {
if (failed(visitNormalBlock(successor)))
return failure();
LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
<< arg << " (id = " << phiID << ")\n");
SmallVector<uint32_t, 8> phiArgs;
phiArgs.push_back(phiTypeID);
phiArgs.push_back(phiID);
for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
Value *value = *(predecessors[predIndex].second + argIndex);
uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
<< ") value " << value << ' ');
// Each pair is a value <id> ...
uint32_t valueId = getValueID(value);
if (valueId == 0) {
// The op generating this value hasn't been visited yet so we don't have
// an <id> assigned yet. Record this to fix up later.
LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
deferredPhiValues[value].push_back(functions.size() + 1 +
phiArgs.size());
} else {
LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
}
phiArgs.push_back(valueId);
// ... and a parent block <id>.
phiArgs.push_back(predBlockId);
}
return success();
encodeInstructionInto(functions, spirv::Opcode::OpPhi, phiArgs);
valueIDMap[arg] = phiID;
}
BlockHandlerType blockHandler;
SmallPtrSet<Block *, 4> doneBlocks;
};
} // namespace
return success();
}
LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
// Assign <id>s to all blocks so that branches inside the SelectionOp can
// resolve properly.
auto &body = selectionOp.body();
for (Block &block : body)
assignBlockID(&block);
getOrCreateBlockID(&block);
auto *headerBlock = selectionOp.getHeaderBlock();
auto *mergeBlock = selectionOp.getMergeBlock();
@ -1353,8 +1520,8 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
// block. The selection header block and merge block are skipped by this
// visitor.
auto handleBlock = [&](Block *block) { return processBlock(block); };
if (failed(ControlFlowBlockVisitor::visit(headerBlock, handleBlock,
{mergeBlock})))
if (failed(PrettyBlockOrderVisitor::visit(headerBlock, handleBlock,
{headerBlock, mergeBlock})))
return failure();
// There is nothing to do for the merge block in the selection, which just
@ -1371,7 +1538,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
auto &body = loopOp.body();
for (Block &block :
llvm::make_range(std::next(body.begin(), 1), body.end())) {
assignBlockID(&block);
getOrCreateBlockID(&block);
}
auto *headerBlock = loopOp.getHeaderBlock();
auto *continueBlock = loopOp.getContinueBlock();
@ -1403,8 +1570,8 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// block. The loop header block, loop continue block, and loop merge block are
// skipped by this visitor and handled later in this function.
auto handleBlock = [&](Block *block) { return processBlock(block); };
if (failed(ControlFlowBlockVisitor::visit(headerBlock, handleBlock,
{continueBlock, mergeBlock})))
if (failed(PrettyBlockOrderVisitor::visit(
headerBlock, handleBlock, {headerBlock, continueBlock, mergeBlock})))
return failure();
// We have handled all other blocks. Now get to the loop continue block.
@ -1421,8 +1588,8 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
LogicalResult Serializer::processBranchConditionalOp(
spirv::BranchConditionalOp condBranchOp) {
auto conditionID = getValueID(condBranchOp.condition());
auto trueLabelID = getBlockID(condBranchOp.getTrueBlock());
auto falseLabelID = getBlockID(condBranchOp.getFalseBlock());
auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
if (auto weights = condBranchOp.branch_weights()) {
@ -1436,7 +1603,7 @@ LogicalResult Serializer::processBranchConditionalOp(
LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
return encodeInstructionInto(functions, spirv::Opcode::OpBranch,
{getBlockID(branchOp.getTarget())});
{getOrCreateBlockID(branchOp.getTarget())});
}
//===----------------------------------------------------------------------===//
@ -1500,6 +1667,8 @@ Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
}
LogicalResult Serializer::processOperation(Operation *op) {
LLVM_DEBUG(llvm::dbgs() << "[op] '" << op->getName() << "'\n");
// First dispatch the ops that do not directly mirror an instruction from
// the SPIR-V spec.
if (auto addressOfOp = dyn_cast<spirv::AddressOfOp>(op)) {
@ -1702,6 +1871,8 @@ LogicalResult spirv::serialize(spirv::ModuleOp module,
if (failed(serializer.serialize()))
return failure();
LLVM_DEBUG(serializer.printValueIDMap(llvm::dbgs()));
serializer.collect(binary);
return success();
}

View File

@ -0,0 +1,165 @@
// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s
// Test branch with one block argument
spv.module "Logical" "GLSL450" {
func @foo() -> () {
// CHECK: %[[CST:.*]] = spv.constant 0
%zero = spv.constant 0 : i32
// CHECK-NEXT: spv.Branch ^bb1(%[[CST]] : i32)
spv.Branch ^bb1(%zero : i32)
// CHECK-NEXT: ^bb1(%{{.*}}: i32):
^bb1(%arg0: i32):
spv.Return
}
func @main() -> () {
spv.Return
}
spv.EntryPoint "GLCompute" @main
} attributes {
capabilities = ["Shader"]
}
// -----
// Test branch with multiple block arguments
spv.module "Logical" "GLSL450" {
func @foo() -> () {
// CHECK: %[[ZERO:.*]] = spv.constant 0
%zero = spv.constant 0 : i32
// CHECK-NEXT: %[[ONE:.*]] = spv.constant 1
%one = spv.constant 1.0 : f32
// CHECK-NEXT: spv.Branch ^bb1(%[[ZERO]], %[[ONE]] : i32, f32)
spv.Branch ^bb1(%zero, %one : i32, f32)
// CHECK-NEXT: ^bb1(%{{.*}}: i32, %{{.*}}: f32): // pred: ^bb0
^bb1(%arg0: i32, %arg1: f32):
spv.Return
}
func @main() -> () {
spv.Return
}
spv.EntryPoint "GLCompute" @main
} attributes {
capabilities = ["Shader"]
}
// -----
// Test using block arguments within branch
spv.module "Logical" "GLSL450" {
func @foo() -> () {
// CHECK: %[[CST0:.*]] = spv.constant 0
%zero = spv.constant 0 : i32
// CHECK-NEXT: spv.Branch ^bb1(%[[CST0]] : i32)
spv.Branch ^bb1(%zero : i32)
// CHECK-NEXT: ^bb1(%[[ARG:.*]]: i32):
^bb1(%arg0: i32):
// CHECK-NEXT: %[[ADD:.*]] = spv.IAdd %[[ARG]], %[[ARG]] : i32
%0 = spv.IAdd %arg0, %arg0 : i32
// CHECK-NEXT: %[[CST1:.*]] = spv.constant 0
// CHECK-NEXT: spv.Branch ^bb2(%[[CST1]], %[[ADD]] : i32, i32)
spv.Branch ^bb2(%zero, %0 : i32, i32)
// CHECK-NEXT: ^bb2(%{{.*}}: i32, %{{.*}}: i32):
^bb2(%arg1: i32, %arg2: i32):
spv.Return
}
func @main() -> () {
spv.Return
}
spv.EntryPoint "GLCompute" @main
} attributes {
capabilities = ["Shader"]
}
// -----
// Test block not following domination order
spv.module "Logical" "GLSL450" {
func @foo() -> () {
// CHECK: spv.Branch ^bb1
spv.Branch ^bb1
// CHECK-NEXT: ^bb1:
// CHECK-NEXT: %[[ZERO:.*]] = spv.constant 0
// CHECK-NEXT: %[[ONE:.*]] = spv.constant 1
// CHECK-NEXT: spv.Branch ^bb2(%[[ZERO]], %[[ONE]] : i32, f32)
// CHECK-NEXT: ^bb2(%{{.*}}: i32, %{{.*}}: f32):
^bb2(%arg0: i32, %arg1: f32):
// CHECK-NEXT: spv.Return
spv.Return
// This block is reordered to follow domination order.
^bb1:
%zero = spv.constant 0 : i32
%one = spv.constant 1.0 : f32
spv.Branch ^bb2(%zero, %one : i32, f32)
}
func @main() -> () {
spv.Return
}
spv.EntryPoint "GLCompute" @main
} attributes {
capabilities = ["Shader"]
}
// -----
// Test multiple predecessors
spv.module "Logical" "GLSL450" {
func @foo() -> () {
%var = spv.Variable : !spv.ptr<i32, Function>
// CHECK: spv.selection
spv.selection {
%true = spv.constant true
// CHECK: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
spv.BranchConditional %true, ^true, ^false
// CHECK-NEXT: ^bb1:
^true:
// CHECK-NEXT: %[[ZERO:.*]] = spv.constant 0
%zero = spv.constant 0 : i32
// CHECK-NEXT: spv.Branch ^bb3(%[[ZERO]] : i32)
spv.Branch ^phi(%zero: i32)
// CHECK-NEXT: ^bb2:
^false:
// CHECK-NEXT: %[[ONE:.*]] = spv.constant 1
%one = spv.constant 1 : i32
// CHECK-NEXT: spv.Branch ^bb3(%[[ONE]] : i32)
spv.Branch ^phi(%one: i32)
// CHECK-NEXT: ^bb3(%[[ARG:.*]]: i32):
^phi(%arg: i32):
// CHECK-NEXT: spv.Store "Function" %{{.*}}, %[[ARG]] : i32
spv.Store "Function" %var, %arg : i32
// CHECK-NEXT: spv.Return
spv.Return
// CHECK-NEXT: ^bb4:
^merge:
// CHECK-NEXT: spv._merge
spv._merge
}
spv.Return
}
func @main() -> () {
spv.Return
}
spv.EntryPoint "GLCompute" @main
} attributes {
capabilities = ["Shader"]
}