From ca2538e9a749deeedb6c790d85d811133e88ad91 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 28 Oct 2019 15:58:11 -0700 Subject: [PATCH] [spirv] Support OpPhi using block arguments This CL adds another control flow instruction in SPIR-V: OpPhi. It is modelled as block arguments to be idiomatic with MLIR. See the rationale.md doc for "Block Arguments vs PHI nodes". Serialization and deserialization is updated to convert between block arguments and SPIR-V OpPhi instructions. PiperOrigin-RevId: 277161545 --- mlir/g3doc/Dialects/SPIR-V.md | 84 ++++- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 5 +- .../SPIRV/Serialization/Deserializer.cpp | 188 ++++++++-- .../SPIRV/Serialization/Serializer.cpp | 333 +++++++++++++----- .../test/Dialect/SPIRV/Serialization/phi.mlir | 165 +++++++++ 5 files changed, 658 insertions(+), 117 deletions(-) create mode 100644 mlir/test/Dialect/SPIRV/Serialization/phi.mlir diff --git a/mlir/g3doc/Dialects/SPIR-V.md b/mlir/g3doc/Dialects/SPIR-V.md index da6c80efb981..82922de6d11e 100644 --- a/mlir/g3doc/Dialects/SPIR-V.md +++ b/mlir/g3doc/Dialects/SPIR-V.md @@ -199,6 +199,12 @@ A SPIR-V function is defined using the builtin `func` op. `spv.module` verifies that the functions inside it comply with SPIR-V requirements: at most one result, no nested functions, and so on. +## Operations + +Operation documentation is written in each op's Op Definition Spec using +TableGen. A markdown version of the doc can be generated using `mlir-tblgen +-gen-doc`. + ## Control Flow SPIR-V binary format uses merge instructions (`OpSelectionMerge` and @@ -385,7 +391,63 @@ func @loop(%count : i32) -> () { } ``` -## Serialization +### Block argument for Phi + +There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi` +instructions are modelled as block arguments in the SPIR-V dialect. (See the +[Rationale][Rationale] doc for "Block Arguments vs Phi nodes".) Each block +argument corresponds to one `OpPhi` instruction in the SPIR-V binary format. 
For +example, for the following SPIR-V function `foo`: + +```spirv + %foo = OpFunction %void None ... +%entry = OpLabel + %var = OpVariable %_ptr_Function_int Function + OpSelectionMerge %merge None + OpBranchConditional %true %true %false + %true = OpLabel + OpBranch %phi +%false = OpLabel + OpBranch %phi + %phi = OpLabel + %val = OpPhi %int %int_1 %false %int_0 %true + OpStore %var %val + OpReturn +%merge = OpLabel + OpReturn + OpFunctionEnd +``` + +It will be represented as: + +```mlir +func @foo() -> () { + %var = spv.Variable : !spv.ptr<i32, Function> + + spv.selection { + %true = spv.constant true + spv.BranchConditional %true, ^true, ^false + + ^true: + %zero = spv.constant 0 : i32 + spv.Branch ^phi(%zero: i32) + + ^false: + %one = spv.constant 1 : i32 + spv.Branch ^phi(%one: i32) + + ^phi(%arg: i32): + spv.Store "Function" %var, %arg : i32 + spv.Return + + ^merge: + spv._merge + } + spv.Return +} +``` + +## Serialization and deserialization The serialization library provides two entry points, `mlir::spirv::serialize()` and `mlir::spirv::deserialize()`, for converting a MLIR SPIR-V module to binary @@ -399,6 +461,25 @@ the SPIR-V binary module and does not guarantee roundtrip equivalence (at least for now). For the latter, please use the assembler/disassembler in the [SPIRV-Tools][SPIRV-Tools] project. +A few transformations are performed in the process of serialization because of +the representational differences between SPIR-V dialect and binary format: + +* Attributes on `spv.module` are emitted as their corresponding SPIR-V + instructions. +* `spv.constant`s are unified and placed in the SPIR-V binary module section + for types, constants, and global variables. +* `spv.selection`s and `spv.loop`s are emitted as basic blocks with `Op*Merge` + instructions in the header block as required by the binary format. 
+ +Similarly, a few transformations are performed during deserialization: + +* Instructions for execution environment requirements will be placed as + attributes on `spv.module`. +* `OpConstant*` instructions are materialized as `spv.constant` at each use + site. +* `OpPhi` instructions are converted to block arguments. +* Structured control flow is placed inside `spv.selection` and `spv.loop`. + [SPIR-V]: https://www.khronos.org/registry/spir-v/ [ArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeArray [ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage @@ -406,3 +487,4 @@ for now). For the latter, please use the assembler/disassembler in the [RuntimeArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeRuntimeArray [StructType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Structure [SPIRV-Tools]: https://github.com/KhronosGroup/SPIRV-Tools +[Rationale]: https://github.com/tensorflow/mlir/blob/master/g3doc/Rationale.md#block-arguments-vs-phi-nodes diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index eb642c7db441..77e457e1b160 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -166,6 +166,7 @@ def SPV_OC_OpBitwiseXor : I32EnumAttrCase<"OpBitwiseXor", 198>; def SPV_OC_OpBitwiseAnd : I32EnumAttrCase<"OpBitwiseAnd", 199>; def SPV_OC_OpControlBarrier : I32EnumAttrCase<"OpControlBarrier", 224>; def SPV_OC_OpMemoryBarrier : I32EnumAttrCase<"OpMemoryBarrier", 225>; +def SPV_OC_OpPhi : I32EnumAttrCase<"OpPhi", 245>; def SPV_OC_OpLoopMerge : I32EnumAttrCase<"OpLoopMerge", 246>; def SPV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerge", 247>; def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>; @@ -205,8 +206,8 @@ def SPV_OpcodeAttr : SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, 
SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, - SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge, - SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, + SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpPhi, + SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue, SPV_OC_OpModuleProcessed ]> { diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index ac9e469fb032..11660ed4e874 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -41,8 +41,8 @@ using namespace mlir; #define DEBUG_TYPE "spirv-deserialization" -// Decodes a string literal in `words` starting at `wordIndex`. Update the -// latter to point to the position in words after the string literal. +/// Decodes a string literal in `words` starting at `wordIndex`. Update the +/// latter to point to the position in words after the string literal. static inline StringRef decodeStringLiteral(ArrayRef words, unsigned &wordIndex) { StringRef str(reinterpret_cast(words.data() + wordIndex)); @@ -50,11 +50,16 @@ static inline StringRef decodeStringLiteral(ArrayRef words, return str; } -// Extracts the opcode from the given first word of a SPIR-V instruction. +/// Extracts the opcode from the given first word of a SPIR-V instruction. static inline spirv::Opcode extractOpcode(uint32_t word) { return static_cast(word & 0xffff); } +/// Returns true if the given `block` is a function entry block. +static inline bool isFnEntryBlock(Block *block) { + return block->isEntryBlock() && isa_and_nonnull(block->getParentOp()); +} + namespace { /// A SPIR-V module serializer. /// @@ -130,6 +135,9 @@ private: /// them to their handler method accordingly. 
LogicalResult processFunction(ArrayRef operands); + /// Processes OpFunctionEnd and finalizes function. This wires up block + /// argument created from OpPhi instructions and also structurizes control + /// flow. LogicalResult processFunctionEnd(ArrayRef operands); /// Gets the constant's attribute and type associated with the given . @@ -220,6 +228,9 @@ private: // Control flow //===--------------------------------------------------------------------===// + /// Returns the block for the given label . + Block *getBlock(uint32_t id) const { return blockMap.lookup(id); } + // In SPIR-V, structured control flow is explicitly declared using merge // instructions (OpSelectionMerge and OpLoopMerge). In the SPIR-V dialect, // we use spv.selection and spv.loop to group structured control flow. @@ -242,9 +253,6 @@ private: // block and redirect all branches to the old header block to the old // merge block (which contains the spv.selection/spv.loop op now). - /// Returns the block for the given label . - Block *getBlock(uint32_t id) const { return blockMap.lookup(id); } - /// A struct for containing a header block's merge and continue targets. struct BlockMergeInfo { Block *mergeBlock; @@ -255,11 +263,24 @@ private: : mergeBlock(m), continueBlock(c) {} }; - /// Returns the merge and continue target info for the given `block` if it is - /// a header block. - BlockMergeInfo getBlockMergeInfo(Block *block) const { - return blockMergeInfo.lookup(block); - } + /// For OpPhi instructions, we use block arguments to represent them. OpPhi + /// encodes a list of (value, predecessor) pairs. At the time of handling the + /// block containing an OpPhi instruction, the predecessor block might not be + /// processed yet, also the value sent by it. So we need to defer handling + /// the block argument from the predecessors. We use the following approach: + /// + /// 1. For each OpPhi instruction, add a block argument to the current block + /// in construction. 
Record the block argument in `valueMap` so its uses + /// can be resolved. For the list of (value, predecessor) pairs, update + /// `blockPhiInfo` for bookkeeping. + /// 2. After processing all blocks, loop over `blockPhiInfo` to fix up each + /// block recorded there to create the proper block arguments on their + /// terminators. + + /// A data structure for containing a SPIR-V block's phi info. It will be + /// represented as block argument in SPIR-V dialect. + using BlockPhiInfo = + SmallVector; // The result <id> of the values sent /// Gets or creates the block corresponding to the given label <id>. The newly /// created block will always be placed at the end of the current function. @@ -278,6 +299,13 @@ private: /// Processes a SPIR-V OpLoopMerge instruction with the given `operands`. LogicalResult processLoopMerge(ArrayRef operands); + /// Processes a SPIR-V OpPhi instruction with the given `operands`. + LogicalResult processPhi(ArrayRef operands); + + /// Creates block arguments on predecessors previously recorded when handling + /// OpPhi instructions. + LogicalResult wireUpBlockArgument(); + /// Extracts blocks belonging to a structured selection/loop into a /// spv.selection/spv.loop op. This method iterates until all blocks /// declared as selection/loop headers are handled. @@ -407,6 +435,9 @@ private: // Header block to its merge (and continue) target mapping. DenseMap blockMergeInfo; + // Block to its phi (block argument) mapping. + DenseMap blockPhiInfo; + // Result <id> to value mapping. 
DenseMap valueMap; @@ -453,7 +484,8 @@ Deserializer::Deserializer(ArrayRef binary, MLIRContext *context) module(createModuleOp()), opBuilder(module->body()) {} LogicalResult Deserializer::deserialize() { - LLVM_DEBUG(llvm::dbgs() << "++ deserialization started\n"); + LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n"); + if (failed(processHeader())) return failure(); @@ -483,7 +515,7 @@ LogicalResult Deserializer::deserialize() { attachCapabilities(); attachExtensions(); - LLVM_DEBUG(llvm::dbgs() << "++ deserialization succeeded\n"); + LLVM_DEBUG(llvm::dbgs() << "+++ completed deserialization +++\n"); return success(); } @@ -748,10 +780,10 @@ LogicalResult Deserializer::processFunction(ArrayRef operands) { auto funcOp = opBuilder.create(unknownLoc, fnName, functionType, ArrayRef()); curFunction = funcMap[operands[1]] = funcOp; - LLVM_DEBUG(llvm::dbgs() << "[fn] processing function " << fnName << " (type=" - << fnType << ", id=" << operands[1] << ")\n"); + LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = " + << fnType << ", id = " << operands[1] << ") --\n"); auto *entryBlock = funcOp.addEntryBlock(); - LLVM_DEBUG(llvm::dbgs() << "[block] created entry block @ " << entryBlock + LLVM_DEBUG(llvm::dbgs() << "[block] created entry block " << entryBlock << "\n"); // Parse the op argument instructions @@ -810,8 +842,8 @@ LogicalResult Deserializer::processFunction(ArrayRef operands) { } if (opcode == spirv::Opcode::OpFunctionEnd) { LLVM_DEBUG(llvm::dbgs() - << "[fn] completed function '" << fnName << "' (type=" << fnType - << ", id=" << operands[1] << ")\n"); + << "-- completed function '" << fnName << "' (type = " << fnType + << ", id = " << operands[1] << ") --\n"); return processFunctionEnd(instOperands); } if (opcode != spirv::Opcode::OpLabel) { @@ -838,8 +870,8 @@ LogicalResult Deserializer::processFunction(ArrayRef operands) { return failure(); } - LLVM_DEBUG(llvm::dbgs() << "[fn] completed function '" << fnName << "' (type=" - << 
fnType << ", id=" << operands[1] << ")\n"); + LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = " + << fnType << ", id = " << operands[1] << ") --\n"); return processFunctionEnd(instOperands); } @@ -849,8 +881,9 @@ LogicalResult Deserializer::processFunctionEnd(ArrayRef operands) { return emitError(unknownLoc, "unexpected operands for OpFunctionEnd"); } + // Wire up block arguments from OpPhi instructions. // Put all structured control flow in spv.selection/spv.loop ops. - if (failed(structurizeControlFlow())) { + if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) { return failure(); } @@ -1438,7 +1471,7 @@ LogicalResult Deserializer::processConstantNull(ArrayRef operands) { Block *Deserializer::getOrCreateBlock(uint32_t id) { if (auto *block = getBlock(id)) { - LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id=" << id + LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id << " @ " << block << "\n"); return block; } @@ -1447,7 +1480,7 @@ Block *Deserializer::getOrCreateBlock(uint32_t id) { // or spv.loop or function). Create it into the function for now and sort // out the proper place later. auto *block = curFunction->addBlock(); - LLVM_DEBUG(llvm::dbgs() << "[block] created block for id=" << id << " @ " + LLVM_DEBUG(llvm::dbgs() << "[block] created block for id = " << id << " @ " << block << "\n"); return blockMap[id] = block; } @@ -1509,7 +1542,7 @@ LogicalResult Deserializer::processLabel(ArrayRef operands) { auto labelID = operands[0]; // We may have forward declared this block. auto *block = getOrCreateBlock(labelID); - LLVM_DEBUG(llvm::dbgs() << "[block] populating block @ " << block << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[block] populating block " << block << "\n"); // If we have seen this block, make sure it was just a forward declaration. 
assert(block->empty() && "re-deserialize the same block!"); @@ -1574,6 +1607,37 @@ LogicalResult Deserializer::processLoopMerge(ArrayRef operands) { return success(); } +LogicalResult Deserializer::processPhi(ArrayRef operands) { + if (!curBlock) { + return emitError(unknownLoc, "OpPhi must appear in a block"); + } + + if (operands.size() < 4) { + return emitError(unknownLoc, "OpPhi must specify result type, result , " + "and variable-parent pairs"); + } + + // Create a block argument for this OpPhi instruction. + Type blockArgType = getType(operands[0]); + BlockArgument *blockArg = curBlock->addArgument(blockArgType); + valueMap[operands[1]] = blockArg; + LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg + << " id = " << operands[1] << " of type " + << blockArgType << '\n'); + + // For each (value, predecessor) pair, insert the value to the predecessor's + // blockPhiInfo entry so later we can fix the block argument there. + for (unsigned i = 2, e = operands.size(); i < e; i += 2) { + uint32_t value = operands[i]; + Block *predecessor = getOrCreateBlock(operands[i + 1]); + blockPhiInfo[predecessor].push_back(value); + LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor + << " with arg id = " << value << '\n'); + } + + return success(); +} + namespace { /// A class for putting all blocks in a structured selection/loop in a /// spv.selection/spv.loop op. 
@@ -1711,6 +1775,14 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { mapper.map(block, newBlock); LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock << " from block " << block << "\n"); + if (!isFnEntryBlock(block)) { + for (BlockArgument *blockArg : block->getArguments()) { + auto *newArg = newBlock->addArgument(blockArg->getType()); + mapper.map(blockArg, newArg); + LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg + << " to " << newArg); + } + } for (auto &op : *block) newBlock->push_back(op.clone(mapper)); @@ -1758,7 +1830,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { // If the function's header block is also part of the structured control // flow, we cannot just simply erase it because it may contain arguments // matching the function signature and used by the cloned blocks. - if (block->isEntryBlock() && isa(block->getParentOp())) { + if (isFnEntryBlock(block)) { LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block << " to only contain a spv.Branch op\n"); // Still keep the function entry block for the potential block arguments, @@ -1775,29 +1847,77 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { return success(); } +LogicalResult Deserializer::wireUpBlockArgument() { + LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n"); + + OpBuilder::InsertionGuard guard(opBuilder); + + for (const auto &info : blockPhiInfo) { + Block *block = info.first; + const BlockPhiInfo &phiInfo = info.second; + LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n"); + LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << '\n'); + + // Set insertion point to before this block's terminator early because we + // may materialize ops via getValue() call. 
+ auto *op = block->getTerminator(); + opBuilder.setInsertionPoint(op); + + SmallVector blockArgs; + blockArgs.reserve(phiInfo.size()); + for (uint32_t valueId : phiInfo) { + if (Value *value = getValue(valueId)) { + blockArgs.push_back(value); + LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value + << " id = " << valueId << '\n'); + } else { + return emitError(unknownLoc, "OpPhi references undefined value!"); + } + } + + if (auto branchOp = dyn_cast(op)) { + // Replace the previous branch op with a new one with block arguments. + opBuilder.create(branchOp.getLoc(), branchOp.getTarget(), + blockArgs); + branchOp.erase(); + } else { + return emitError(unknownLoc, "unimplemented terminator for Phi creation"); + } + + LLVM_DEBUG(llvm::dbgs() << "[phi] after creating block argument:\n"); + LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << '\n'); + } + blockPhiInfo.clear(); + + LLVM_DEBUG(llvm::dbgs() << "[phi] completed wiring up block arguments\n"); + return success(); +} + LogicalResult Deserializer::structurizeControlFlow() { LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n"); - while (!blockMergeInfo.empty()) { - auto *headerBlock = blockMergeInfo.begin()->first; - LLVM_DEBUG(llvm::dbgs() << "[cf] header block @ " << headerBlock << "\n"); + for (const auto &info : blockMergeInfo) { + auto *headerBlock = info.first; + LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << "\n"); - const auto &mergeInfo = blockMergeInfo.begin()->second; + const auto &mergeInfo = info.second; auto *mergeBlock = mergeInfo.mergeBlock; auto *continueBlock = mergeInfo.continueBlock; assert(mergeBlock && "merge block cannot be nullptr"); - LLVM_DEBUG(llvm::dbgs() << "[cf] merge block @ " << mergeBlock << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << "\n"); if (continueBlock) { LLVM_DEBUG(llvm::dbgs() - << "[cf] continue block @ " << continueBlock << "\n"); + << "[cf] continue block " << 
continueBlock << "\n"); } if (failed(ControlFlowStructurizer::structurize(unknownLoc, headerBlock, mergeBlock, continueBlock))) return failure(); - - blockMergeInfo.erase(headerBlock); } + blockMergeInfo.clear(); LLVM_DEBUG(llvm::dbgs() << "[cf] completed structurizing control flow\n"); return success(); @@ -1949,6 +2069,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return processSelectionMerge(operands); case spirv::Opcode::OpLoopMerge: return processLoopMerge(operands); + case spirv::Opcode::OpPhi: + return processPhi(operands); case spirv::Opcode::OpUndef: return processUndef(operands); default: diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 241be2a42975..afb9e0b81b7a 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -32,8 +32,11 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#define DEBUG_TYPE "spirv-serialization" + using namespace mlir; /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into @@ -49,6 +52,77 @@ LogicalResult encodeInstructionInto(SmallVectorImpl &binary, return success(); } +namespace { +/// A pre-order depth-first vistor for processing basic blocks. +/// +/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order +/// of blocks in a function must satisfy the rule that blocks appear before all +/// blocks they dominate." This can be achieved by a pre-order CFG traversal +/// algorithm. To make the serialization output more logical and readable to +/// human, we perform depth-first CFG traversal and delay the serialization of +/// the merge block and the continue block, if exists, until after all other +/// blocks have been processed. 
+/// +/// This visitor is special tailored for SPIR-V functions, spv.selection or +/// spv.loop block serialization to satisfy SPIR-V validation rules. It should +/// not be used as a general depth-first block visitor. +class PrettyBlockOrderVisitor { +public: + using BlockHandlerType = llvm::function_ref; + + /// Visits the basic blocks starting from the given `headerBlock`'s successors + /// in pre-order depth-first manner and calls `blockHandler` on each block. + /// Skips handling blocks in the `skipBlocks` list. If `headerBlock` is also + /// in `skipBlocks` list, still handles all its successors. + static LogicalResult visit(Block *headerBlock, BlockHandlerType blockHandler, + ArrayRef skipBlocks = {}) { + return PrettyBlockOrderVisitor(blockHandler, skipBlocks) + .visitHeaderBlock(headerBlock); + } + +private: + PrettyBlockOrderVisitor(BlockHandlerType blockHandler, + ArrayRef skipBlocks) + : blockHandler(blockHandler), + doneBlocks(skipBlocks.begin(), skipBlocks.end()) {} + + LogicalResult visitHeaderBlock(Block *header) { + // Skip processing the header block if requested. + if (!llvm::is_contained(doneBlocks, header)) { + if (failed(blockHandler(header))) + return failure(); + doneBlocks.insert(header); + } + + for (auto *successor : header->getSuccessors()) { + if (failed(visitNormalBlock(successor))) + return failure(); + } + + return success(); + } + + LogicalResult visitNormalBlock(Block *block) { + if (doneBlocks.count(block)) + return success(); + + if (failed(blockHandler(block))) + return failure(); + doneBlocks.insert(block); + + for (auto *successor : block->getSuccessors()) { + if (failed(visitNormalBlock(successor))) + return failure(); + } + + return success(); + } + + BlockHandlerType blockHandler; + SmallPtrSet doneBlocks; +}; +} // namespace + namespace { /// A SPIR-V module serializer. @@ -75,6 +149,9 @@ public: /// Collects the final SPIR-V `binary`. 
void collect(SmallVectorImpl &binary); + /// (For debugging) prints each value and its corresponding result . + void printValueIDMap(raw_ostream &os); + private: // Note that there are two main categories of methods in this class: // * process*() methods are meant to fully serialize a SPIR-V module entity @@ -244,19 +321,25 @@ private: // Control flow //===--------------------------------------------------------------------===// + /// Returns the result for the given block. uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); } - uint32_t assignBlockID(Block *block); + /// Returns the result for the given block. If no has been assigned, + /// assigns the next available + uint32_t getOrCreateBlockID(Block *block); - // Processes the given `block` and emits SPIR-V instructions for all ops - // inside. Does not emit OpLabel for this block if `omitLabel` is true. - // `actionBeforeTerminator` is a callback that will be invoked before handling - // the terminator op. It can be used to inject the Op*Merge instruction if - // this is a SPIR-V selection/loop header block. + /// Processes the given `block` and emits SPIR-V instructions for all ops + /// inside. Does not emit OpLabel for this block if `omitLabel` is true. + /// `actionBeforeTerminator` is a callback that will be invoked before + /// handling the terminator op. It can be used to inject the Op*Merge + /// instruction if this is a SPIR-V selection/loop header block. LogicalResult processBlock(Block *block, bool omitLabel = false, llvm::function_ref actionBeforeTerminator = nullptr); + /// Emits OpPhi instructions for the given block if it has block arguments. + LogicalResult emitPhiForBlockArguments(Block *block); + LogicalResult processSelectionOp(spirv::SelectionOp selectionOp); LogicalResult processLoopOp(spirv::LoopOp loopOp); @@ -356,6 +439,46 @@ private: /// Map from extended instruction set name to s. 
llvm::StringMap extendedInstSetIDMap; + + /// Map from values used in OpPhi instructions to their offset in the + /// `functions` section. + /// + /// When processing a block with arguments, we need to emit OpPhi + /// instructions to record the predecessor block <id>s and the values they + /// send to the block in question. But it's not guaranteed all values are + /// visited and thus assigned result <id>s. So we need this list to capture + /// the offsets into `functions` where a value is used so that we can fix it + /// up later after processing all the blocks in a function. + /// + /// More concretely, say if we are visiting the following blocks: + /// + /// ```mlir + /// ^phi(%arg0: i32): + /// ... + /// ^parent1: + /// ... + /// spv.Branch ^phi(%val0: i32) + /// ^parent2: + /// ... + /// spv.Branch ^phi(%val1: i32) + /// ``` + /// + /// When we are serializing the `^phi` block, we need to emit at the beginning + /// of the block OpPhi instructions which have the following parameters: + /// + /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1 + /// id-for-%val1 id-for-^parent2 + /// + /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit + /// all the blocks twice and use the first visit to assign an <id> to each + /// value. But it's paying the overheads just for OpPhi emission. Instead, + /// we still visit the blocks once for emission. When we emit the OpPhi + /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1. + /// At the same time, we record their offsets in the emitted binary (which is + /// placed inside `functions`) here. And then after emitting all blocks, we + /// replace the dummy 0 with the real result <id> by overwriting + /// `functions[offset]`. 
+ DenseMap> deferredPhiValues; }; } // namespace @@ -363,6 +486,8 @@ Serializer::Serializer(spirv::ModuleOp module) : module(module), mlirBuilder(module.getContext()) {} LogicalResult Serializer::serialize() { + LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); + if (failed(module.verify())) return failure(); @@ -378,6 +503,8 @@ LogicalResult Serializer::serialize() { return failure(); } } + + LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); return success(); } @@ -403,6 +530,24 @@ void Serializer::collect(SmallVectorImpl &binary) { binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); binary.append(functions.begin(), functions.end()); } + +void Serializer::printValueIDMap(raw_ostream &os) { + os << "\n= Value Map =\n\n"; + for (auto valueIDPair : valueIDMap) { + Value *val = valueIDPair.first; + os << " " << val << " " + << "id = " << valueIDPair.second << ' '; + if (auto *op = val->getDefiningOp()) { + os << "from op '" << op->getName() << "'"; + } else if (auto *arg = dyn_cast(val)) { + Block *block = arg->getOwner(); + os << "from argument of block " << block << ' '; + os << " in op '" << block->getParentOp()->getName() << "'"; + } + os << '\n'; + } +} + //===----------------------------------------------------------------------===// // Module structure //===----------------------------------------------------------------------===// @@ -564,6 +709,8 @@ Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex, } // namespace LogicalResult Serializer::processFuncOp(FuncOp op) { + LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); + uint32_t fnTypeID = 0; // Generate type of the function. 
processType(op.getLoc(), op.getType(), fnTypeID); @@ -610,11 +757,24 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { return op.emitError("external function is unhandled"); } - for (auto &block : op) { - if (failed(processBlock(&block))) - return failure(); - } + if (failed(PrettyBlockOrderVisitor::visit( + &op.front(), [&](Block *block) { return processBlock(block); }))) + return failure(); + // There might be OpPhi instructions who have value references needing to fix. + for (auto deferredValue : deferredPhiValues) { + Value *value = deferredValue.first; + uint32_t id = getValueID(value); + LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value + << " to id = " << id << '\n'); + assert(id && "OpPhi references undefined value!"); + for (size_t offset : deferredValue.second) + functions[offset] = id; + } + deferredPhiValues.clear(); + + LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() + << "' --\n"); // Insert OpFunctionEnd. return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {}); } @@ -842,7 +1002,7 @@ Serializer::prepareFunctionType(Location loc, FunctionType type, SmallVectorImpl &operands) { typeEnum = spirv::Opcode::OpTypeFunction; assert(type.getNumResults() <= 1 && - "Serialization supports only a single return value"); + "serialization supports only a single return value"); uint32_t resultID = 0; if (failed(processType( loc, type.getNumResults() == 1 ? 
type.getResult(0) : getVoidType(), @@ -1221,24 +1381,31 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, // Control flow //===----------------------------------------------------------------------===// -uint32_t Serializer::assignBlockID(Block *block) { - assert(blockIDMap.lookup(block) == 0 && "block already has "); +uint32_t Serializer::getOrCreateBlockID(Block *block) { + if (uint32_t id = getBlockID(block)) + return id; return blockIDMap[block] = getNextID(); } LogicalResult Serializer::processBlock(Block *block, bool omitLabel, llvm::function_ref actionBeforeTerminator) { + LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); + LLVM_DEBUG(block->print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << '\n'); if (!omitLabel) { - auto blockID = getBlockID(block); - if (blockID == 0) { - blockID = assignBlockID(block); - } + uint32_t blockID = getOrCreateBlockID(block); + LLVM_DEBUG(llvm::dbgs() + << "[block] " << block << " (id = " << blockID << ")\n"); // Emit OpLabel for this block. encodeInstructionInto(functions, spirv::Opcode::OpLabel, {blockID}); } + // Emit OpPhi instructions for block arguments, if any. + if (failed(emitPhiForBlockArguments(block))) + return failure(); + // Process each op in this block except the terminator. for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { if (failed(processOperation(&op))) @@ -1254,78 +1421,78 @@ Serializer::processBlock(Block *block, bool omitLabel, return success(); } -namespace { -/// A pre-order depth-first vistor for processing basic blocks in a spv.loop op. -/// -/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order -/// of blocks in a function must satisfy the rule that blocks appear before all -/// blocks they dominate." This can be achieved by a pre-order CFG traversal -/// algorithm. 
To make the serialization output more logical and readable to -/// human, we perform depth-first CFG traversal and delay the serialization of -/// the merge block (and the continue block) until after all other blocks have -/// been processed. -/// -/// This visitor is special tailored for spv.selection or spv.loop block -/// serialization to satisfy SPIR-V validation rules. It should not be used -/// as a general depth-first block visitor. -class ControlFlowBlockVisitor { -public: - using BlockHandlerType = llvm::function_ref; - - /// Visits the basic blocks starting from the given `headerBlock`'s successors - /// in pre-order depth-first manner and calls `blockHandler` on each block. - /// Skips handling the `headerBlock` and blocks in the `skipBlocks` list. - static LogicalResult visit(Block *headerBlock, BlockHandlerType blockHandler, - ArrayRef skipBlocks) { - return ControlFlowBlockVisitor(blockHandler, skipBlocks) - .visitHeaderBlock(headerBlock); - } - -private: - ControlFlowBlockVisitor(BlockHandlerType blockHandler, - ArrayRef skipBlocks) - : blockHandler(blockHandler), - doneBlocks(skipBlocks.begin(), skipBlocks.end()) {} - - LogicalResult visitHeaderBlock(Block *header) { - // Skip processing the header block. - doneBlocks.insert(header); - - for (auto *successor : header->getSuccessors()) { - if (failed(visitNormalBlock(successor))) - return failure(); - } - +LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { + // Nothing to do if this block has no arguments or it's the entry block, which + // always has the same arguments as the function signature. + if (block->args_empty() || block->isEntryBlock()) return success(); + + // If the block has arguments, we need to create SPIR-V OpPhi instructions. + // A SPIR-V OpPhi instruction is of the syntax: + // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair + // So we need to collect all predecessor blocks and the arguments they send + // to this block.
+ SmallVector, 4> predecessors; + for (Block *predecessor : block->getPredecessors()) { + auto *op = predecessor->getTerminator(); + if (auto branchOp = dyn_cast(op)) { + predecessors.emplace_back(predecessor, branchOp.operand_begin()); + } else { + return op->emitError("unimplemented terminator for Phi creation"); + } } - LogicalResult visitNormalBlock(Block *block) { - if (doneBlocks.count(block)) - return success(); + // Then create an OpPhi instruction for each of the block arguments. + for (auto argIndex : llvm::seq(0, block->getNumArguments())) { + BlockArgument *arg = block->getArgument(argIndex); - if (failed(blockHandler(block))) + // Get the type and result for this OpPhi instruction. + uint32_t phiTypeID = 0; + if (failed(processType(arg->getLoc(), arg->getType(), phiTypeID))) return failure(); - doneBlocks.insert(block); + uint32_t phiID = getNextID(); - for (auto *successor : block->getSuccessors()) { - if (failed(visitNormalBlock(successor))) - return failure(); + LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' + << arg << " (id = " << phiID << ")\n"); + + SmallVector phiArgs; + phiArgs.push_back(phiTypeID); + phiArgs.push_back(phiID); + + for (auto predIndex : llvm::seq(0, predecessors.size())) { + Value *value = *(predecessors[predIndex].second + argIndex); + uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); + LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId + << ") value " << value << ' '); + // Each pair is a value <id> ... + uint32_t valueId = getValueID(value); + if (valueId == 0) { + // The op generating this value hasn't been visited yet so we don't have + // an <id> assigned yet. Record this to fix up later. + LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); + deferredPhiValues[value].push_back(functions.size() + 1 + + phiArgs.size()); + } else { + LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); + } + phiArgs.push_back(valueId); + // ... and a parent block <id>.
+ phiArgs.push_back(predBlockId); } - return success(); + encodeInstructionInto(functions, spirv::Opcode::OpPhi, phiArgs); + valueIDMap[arg] = phiID; } - BlockHandlerType blockHandler; - SmallPtrSet doneBlocks; -}; -} // namespace + return success(); +} LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { // Assign <id>s to all blocks so that branches inside the SelectionOp can // resolve properly. auto &body = selectionOp.body(); for (Block &block : body) - assignBlockID(&block); + getOrCreateBlockID(&block); auto *headerBlock = selectionOp.getHeaderBlock(); auto *mergeBlock = selectionOp.getMergeBlock(); @@ -1353,8 +1520,8 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { // block. The selection header block and merge block are skipped by this // visitor. auto handleBlock = [&](Block *block) { return processBlock(block); }; - if (failed(ControlFlowBlockVisitor::visit(headerBlock, handleBlock, - {mergeBlock}))) + if (failed(PrettyBlockOrderVisitor::visit(headerBlock, handleBlock, + {headerBlock, mergeBlock}))) return failure(); // There is nothing to do for the merge block in the selection, which just @@ -1371,7 +1538,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { auto &body = loopOp.body(); for (Block &block : llvm::make_range(std::next(body.begin(), 1), body.end())) { - assignBlockID(&block); + getOrCreateBlockID(&block); } auto *headerBlock = loopOp.getHeaderBlock(); auto *continueBlock = loopOp.getContinueBlock(); @@ -1403,8 +1570,8 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { // block. The loop header block, loop continue block, and loop merge block are // skipped by this visitor and handled later in this function.
auto handleBlock = [&](Block *block) { return processBlock(block); }; - if (failed(ControlFlowBlockVisitor::visit(headerBlock, handleBlock, - {continueBlock, mergeBlock}))) + if (failed(PrettyBlockOrderVisitor::visit( + headerBlock, handleBlock, {headerBlock, continueBlock, mergeBlock}))) return failure(); // We have handled all other blocks. Now get to the loop continue block. @@ -1421,8 +1588,8 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { LogicalResult Serializer::processBranchConditionalOp( spirv::BranchConditionalOp condBranchOp) { auto conditionID = getValueID(condBranchOp.condition()); - auto trueLabelID = getBlockID(condBranchOp.getTrueBlock()); - auto falseLabelID = getBlockID(condBranchOp.getFalseBlock()); + auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock()); + auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock()); SmallVector arguments{conditionID, trueLabelID, falseLabelID}; if (auto weights = condBranchOp.branch_weights()) { @@ -1436,7 +1603,7 @@ LogicalResult Serializer::processBranchConditionalOp( LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { return encodeInstructionInto(functions, spirv::Opcode::OpBranch, - {getBlockID(branchOp.getTarget())}); + {getOrCreateBlockID(branchOp.getTarget())}); } //===----------------------------------------------------------------------===// @@ -1500,6 +1667,8 @@ Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { } LogicalResult Serializer::processOperation(Operation *op) { + LLVM_DEBUG(llvm::dbgs() << "[op] '" << op->getName() << "'\n"); + // First dispatch the ops that do not directly mirror an instruction from // the SPIR-V spec. 
if (auto addressOfOp = dyn_cast(op)) { @@ -1702,6 +1871,8 @@ LogicalResult spirv::serialize(spirv::ModuleOp module, if (failed(serializer.serialize())) return failure(); + LLVM_DEBUG(serializer.printValueIDMap(llvm::dbgs())); + serializer.collect(binary); return success(); } diff --git a/mlir/test/Dialect/SPIRV/Serialization/phi.mlir b/mlir/test/Dialect/SPIRV/Serialization/phi.mlir new file mode 100644 index 000000000000..58e64e35f17b --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/phi.mlir @@ -0,0 +1,165 @@ +// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s + +// Test branch with one block argument + +spv.module "Logical" "GLSL450" { + func @foo() -> () { +// CHECK: %[[CST:.*]] = spv.constant 0 + %zero = spv.constant 0 : i32 +// CHECK-NEXT: spv.Branch ^bb1(%[[CST]] : i32) + spv.Branch ^bb1(%zero : i32) +// CHECK-NEXT: ^bb1(%{{.*}}: i32): + ^bb1(%arg0: i32): + spv.Return + } + + func @main() -> () { + spv.Return + } + spv.EntryPoint "GLCompute" @main +} attributes { + capabilities = ["Shader"] +} + +// ----- + +// Test branch with multiple block arguments + +spv.module "Logical" "GLSL450" { + func @foo() -> () { +// CHECK: %[[ZERO:.*]] = spv.constant 0 + %zero = spv.constant 0 : i32 +// CHECK-NEXT: %[[ONE:.*]] = spv.constant 1 + %one = spv.constant 1.0 : f32 +// CHECK-NEXT: spv.Branch ^bb1(%[[ZERO]], %[[ONE]] : i32, f32) + spv.Branch ^bb1(%zero, %one : i32, f32) + +// CHECK-NEXT: ^bb1(%{{.*}}: i32, %{{.*}}: f32): // pred: ^bb0 + ^bb1(%arg0: i32, %arg1: f32): + spv.Return + } + + func @main() -> () { + spv.Return + } + spv.EntryPoint "GLCompute" @main +} attributes { + capabilities = ["Shader"] +} + +// ----- + +// Test using block arguments within branch + +spv.module "Logical" "GLSL450" { + func @foo() -> () { +// CHECK: %[[CST0:.*]] = spv.constant 0 + %zero = spv.constant 0 : i32 +// CHECK-NEXT: spv.Branch ^bb1(%[[CST0]] : i32) + spv.Branch ^bb1(%zero : i32) + +// CHECK-NEXT: ^bb1(%[[ARG:.*]]: i32): + ^bb1(%arg0: i32): +// 
CHECK-NEXT: %[[ADD:.*]] = spv.IAdd %[[ARG]], %[[ARG]] : i32 + %0 = spv.IAdd %arg0, %arg0 : i32 +// CHECK-NEXT: %[[CST1:.*]] = spv.constant 0 +// CHECK-NEXT: spv.Branch ^bb2(%[[CST1]], %[[ADD]] : i32, i32) + spv.Branch ^bb2(%zero, %0 : i32, i32) + +// CHECK-NEXT: ^bb2(%{{.*}}: i32, %{{.*}}: i32): + ^bb2(%arg1: i32, %arg2: i32): + spv.Return + } + + func @main() -> () { + spv.Return + } + spv.EntryPoint "GLCompute" @main +} attributes { + capabilities = ["Shader"] +} + +// ----- + +// Test block not following domination order + +spv.module "Logical" "GLSL450" { + func @foo() -> () { +// CHECK: spv.Branch ^bb1 + spv.Branch ^bb1 + +// CHECK-NEXT: ^bb1: +// CHECK-NEXT: %[[ZERO:.*]] = spv.constant 0 +// CHECK-NEXT: %[[ONE:.*]] = spv.constant 1 +// CHECK-NEXT: spv.Branch ^bb2(%[[ZERO]], %[[ONE]] : i32, f32) + +// CHECK-NEXT: ^bb2(%{{.*}}: i32, %{{.*}}: f32): + ^bb2(%arg0: i32, %arg1: f32): +// CHECK-NEXT: spv.Return + spv.Return + + // This block is reordered to follow domination order. + ^bb1: + %zero = spv.constant 0 : i32 + %one = spv.constant 1.0 : f32 + spv.Branch ^bb2(%zero, %one : i32, f32) + } + + func @main() -> () { + spv.Return + } + spv.EntryPoint "GLCompute" @main +} attributes { + capabilities = ["Shader"] +} + +// ----- + +// Test multiple predecessors + +spv.module "Logical" "GLSL450" { + func @foo() -> () { + %var = spv.Variable : !spv.ptr + +// CHECK: spv.selection + spv.selection { + %true = spv.constant true +// CHECK: spv.BranchConditional %{{.*}}, ^bb1, ^bb2 + spv.BranchConditional %true, ^true, ^false + +// CHECK-NEXT: ^bb1: + ^true: +// CHECK-NEXT: %[[ZERO:.*]] = spv.constant 0 + %zero = spv.constant 0 : i32 +// CHECK-NEXT: spv.Branch ^bb3(%[[ZERO]] : i32) + spv.Branch ^phi(%zero: i32) + +// CHECK-NEXT: ^bb2: + ^false: +// CHECK-NEXT: %[[ONE:.*]] = spv.constant 1 + %one = spv.constant 1 : i32 +// CHECK-NEXT: spv.Branch ^bb3(%[[ONE]] : i32) + spv.Branch ^phi(%one: i32) + +// CHECK-NEXT: ^bb3(%[[ARG:.*]]: i32): + ^phi(%arg: i32): +// CHECK-NEXT: 
spv.Store "Function" %{{.*}}, %[[ARG]] : i32 + spv.Store "Function" %var, %arg : i32 +// CHECK-NEXT: spv.Return + spv.Return + +// CHECK-NEXT: ^bb4: + ^merge: +// CHECK-NEXT: spv._merge + spv._merge + } + spv.Return + } + + func @main() -> () { + spv.Return + } + spv.EntryPoint "GLCompute" @main +} attributes { + capabilities = ["Shader"] +}