diff --git a/mlir/g3doc/Dialects/SPIR-V.md b/mlir/g3doc/Dialects/SPIR-V.md index da6c80efb981..82922de6d11e 100644 --- a/mlir/g3doc/Dialects/SPIR-V.md +++ b/mlir/g3doc/Dialects/SPIR-V.md @@ -199,6 +199,12 @@ A SPIR-V function is defined using the builtin `func` op. `spv.module` verifies that the functions inside it comply with SPIR-V requirements: at most one result, no nested functions, and so on. +## Operations + +Operation documentation is written in each op's Op Definition Spec using +TableGen. A markdown version of the doc can be generated using `mlir-tblgen +-gen-doc`. + ## Control Flow SPIR-V binary format uses merge instructions (`OpSelectionMerge` and @@ -385,7 +391,63 @@ func @loop(%count : i32) -> () { } ``` -## Serialization +### Block argument for Phi + +There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi` +instructions are modelled as block arguments in the SPIR-V dialect. (See the +[Rationale][Rationale] doc for "Block Arguments vs Phi nodes".) Each block +argument corresponds to one `OpPhi` instruction in the SPIR-V binary format. For +example, for the following SPIR-V function `foo`: + +```spirv + %foo = OpFunction %void None ... +%entry = OpLabel + %var = OpVariable %_ptr_Function_int Function + OpSelectionMerge %merge None + OpBranchConditional %true %true %false + %true = OpLabel + OpBranch %phi +%false = OpLabel + OpBranch %phi + %phi = OpLabel + %val = OpPhi %int %int_1 %false %int_0 %true + OpStore %var %val + OpReturn +%merge = OpLabel + OpReturn + OpFunctionEnd +``` + +It will be represented as: + +```mlir +func @foo() -> () { + %var = spv.Variable : !spv.ptr + + spv.selection { + %true = spv.constant true + spv.BranchConditional %true, ^true, ^false + + ^true: + %zero = spv.constant 0 : i32 + spv.Branch ^phi(%zero: i32) + + ^false: + %one = spv.constant 1 : i32 + spv.Branch ^phi(%one: i32) + + ^phi(%arg: i32): + spv.Store "Function" %var, %arg : i32 + spv.Return + + ^merge: + spv._merge + } + spv.Return +} +``` + +## Serialization and deserialization The serialization library provides two entry points, `mlir::spirv::serialize()` and `mlir::spirv::deserialize()`, for converting a MLIR SPIR-V module to binary @@ -399,6 +461,25 @@ the SPIR-V binary module and does not guarantee roundtrip equivalence (at least for now). For the latter, please use the assembler/disassembler in the [SPIRV-Tools][SPIRV-Tools] project. +A few transformations are performed in the process of serialization because of +the representational differences between SPIR-V dialect and binary format: + +* Attributes on `spv.module` are emitted as their corresponding SPIR-V + instructions. +* `spv.constant`s are unified and placed in the SPIR-V binary module section + for types, constants, and global variables. +* `spv.selection`s and `spv.loop`s are emitted as basic blocks with `Op*Merge` + instructions in the header block as required by the binary format. + +Similarly, a few transformations are performed during deserialization: + +* Instructions for execution environment requirements will be placed as + attribues on `spv.module`. +* `OpConstant*` instructions are materialized as `spv.constant` at each use + site. +* `OpPhi` instructions are converted to block arguments. +* Structured control flow are placed inside `spv.selection` and `spv.loop`. + [SPIR-V]: https://www.khronos.org/registry/spir-v/ [ArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeArray [ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage @@ -406,3 +487,4 @@ for now). For the latter, please use the assembler/disassembler in the [RuntimeArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeRuntimeArray [StructType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Structure [SPIRV-Tools]: https://github.com/KhronosGroup/SPIRV-Tools +[Rationale]: https://github.com/tensorflow/mlir/blob/master/g3doc/Rationale.md#block-arguments-vs-phi-nodes diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index eb642c7db441..77e457e1b160 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -166,6 +166,7 @@ def SPV_OC_OpBitwiseXor : I32EnumAttrCase<"OpBitwiseXor", 198>; def SPV_OC_OpBitwiseAnd : I32EnumAttrCase<"OpBitwiseAnd", 199>; def SPV_OC_OpControlBarrier : I32EnumAttrCase<"OpControlBarrier", 224>; def SPV_OC_OpMemoryBarrier : I32EnumAttrCase<"OpMemoryBarrier", 225>; +def SPV_OC_OpPhi : I32EnumAttrCase<"OpPhi", 245>; def SPV_OC_OpLoopMerge : I32EnumAttrCase<"OpLoopMerge", 246>; def SPV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerge", 247>; def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>; @@ -205,8 +206,8 @@ def SPV_OpcodeAttr : SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, - SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge, - SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, + SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpPhi, + SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue, SPV_OC_OpModuleProcessed ]> { diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index ac9e469fb032..11660ed4e874 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -41,8 +41,8 @@ using namespace mlir; #define DEBUG_TYPE "spirv-deserialization" -// Decodes a string literal in `words` starting at `wordIndex`. Update the -// latter to point to the position in words after the string literal. +/// Decodes a string literal in `words` starting at `wordIndex`. Update the +/// latter to point to the position in words after the string literal. static inline StringRef decodeStringLiteral(ArrayRef words, unsigned &wordIndex) { StringRef str(reinterpret_cast(words.data() + wordIndex)); @@ -50,11 +50,16 @@ static inline StringRef decodeStringLiteral(ArrayRef words, return str; } -// Extracts the opcode from the given first word of a SPIR-V instruction. +/// Extracts the opcode from the given first word of a SPIR-V instruction. static inline spirv::Opcode extractOpcode(uint32_t word) { return static_cast(word & 0xffff); } +/// Returns true if the given `block` is a function entry block. +static inline bool isFnEntryBlock(Block *block) { + return block->isEntryBlock() && isa_and_nonnull(block->getParentOp()); +} + namespace { /// A SPIR-V module serializer. /// @@ -130,6 +135,9 @@ private: /// them to their handler method accordingly. LogicalResult processFunction(ArrayRef operands); + /// Processes OpFunctionEnd and finalizes function. This wires up block + /// argument created from OpPhi instructions and also structurizes control + /// flow. LogicalResult processFunctionEnd(ArrayRef operands); /// Gets the constant's attribute and type associated with the given . @@ -220,6 +228,9 @@ private: // Control flow //===--------------------------------------------------------------------===// + /// Returns the block for the given label . + Block *getBlock(uint32_t id) const { return blockMap.lookup(id); } + // In SPIR-V, structured control flow is explicitly declared using merge // instructions (OpSelectionMerge and OpLoopMerge). In the SPIR-V dialect, // we use spv.selection and spv.loop to group structured control flow. @@ -242,9 +253,6 @@ private: // block and redirect all branches to the old header block to the old // merge block (which contains the spv.selection/spv.loop op now). - /// Returns the block for the given label . - Block *getBlock(uint32_t id) const { return blockMap.lookup(id); } - /// A struct for containing a header block's merge and continue targets. struct BlockMergeInfo { Block *mergeBlock; @@ -255,11 +263,24 @@ private: : mergeBlock(m), continueBlock(c) {} }; - /// Returns the merge and continue target info for the given `block` if it is - /// a header block. - BlockMergeInfo getBlockMergeInfo(Block *block) const { - return blockMergeInfo.lookup(block); - } + /// For OpPhi instructions, we use block arguments to represent them. OpPhi + /// encodes a list of (value, predecessor) pairs. At the time of handling the + /// block containing an OpPhi instruction, the predecessor block might not be + /// processed yet, also the value sent by it. So we need to defer handling + /// the block argument from the predecessors. We use the following approach: + /// + /// 1. For each OpPhi instruction, add a block argument to the current block + /// in construction. Record the block argment in `valueMap` so its uses + /// can be resolved. For the list of (value, predecessor) pairs, update + /// `blockPhiInfo` for bookkeeping. + /// 2. After processing all blocks, loop over `blockPhiInfo` to fix up each + /// block recorded there to create the proper block arguments on their + /// terminators. + + /// A data structure for containing a SPIR-V block's phi info. It will be + /// represented as block argument in SPIR-V dialect. + using BlockPhiInfo = + SmallVector; // The result of the values sent /// Gets or creates the block corresponding to the given label . The newly /// created block will always be placed at the end of the current function. @@ -278,6 +299,13 @@ private: /// Processes a SPIR-V OpLoopMerge instruction with the given `operands`. LogicalResult processLoopMerge(ArrayRef operands); + /// Processes a SPIR-V OpPhi instruction with the given `operands`. + LogicalResult processPhi(ArrayRef operands); + + /// Creates block arguments on predecessors previously recorded when handling + /// OpPhi instructions. + LogicalResult wireUpBlockArgument(); + /// Extracts blocks belonging to a structured selection/loop into a /// spv.selection/spv.loop op. This method iterates until all blocks /// declared as selection/loop headers are handled. @@ -407,6 +435,9 @@ private: // Header block to its merge (and continue) target mapping. DenseMap blockMergeInfo; + // Block to its phi (block argument) mapping. + DenseMap blockPhiInfo; + // Result to value mapping. DenseMap valueMap; @@ -453,7 +484,8 @@ Deserializer::Deserializer(ArrayRef binary, MLIRContext *context) module(createModuleOp()), opBuilder(module->body()) {} LogicalResult Deserializer::deserialize() { - LLVM_DEBUG(llvm::dbgs() << "++ deserialization started\n"); + LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n"); + if (failed(processHeader())) return failure(); @@ -483,7 +515,7 @@ LogicalResult Deserializer::deserialize() { attachCapabilities(); attachExtensions(); - LLVM_DEBUG(llvm::dbgs() << "++ deserialization succeeded\n"); + LLVM_DEBUG(llvm::dbgs() << "+++ completed deserialization +++\n"); return success(); } @@ -748,10 +780,10 @@ LogicalResult Deserializer::processFunction(ArrayRef operands) { auto funcOp = opBuilder.create(unknownLoc, fnName, functionType, ArrayRef()); curFunction = funcMap[operands[1]] = funcOp; - LLVM_DEBUG(llvm::dbgs() << "[fn] processing function " << fnName << " (type=" - << fnType << ", id=" << operands[1] << ")\n"); + LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = " + << fnType << ", id = " << operands[1] << ") --\n"); auto *entryBlock = funcOp.addEntryBlock(); - LLVM_DEBUG(llvm::dbgs() << "[block] created entry block @ " << entryBlock + LLVM_DEBUG(llvm::dbgs() << "[block] created entry block " << entryBlock << "\n"); // Parse the op argument instructions @@ -810,8 +842,8 @@ LogicalResult Deserializer::processFunction(ArrayRef operands) { } if (opcode == spirv::Opcode::OpFunctionEnd) { LLVM_DEBUG(llvm::dbgs() - << "[fn] completed function '" << fnName << "' (type=" << fnType - << ", id=" << operands[1] << ")\n"); + << "-- completed function '" << fnName << "' (type = " << fnType + << ", id = " << operands[1] << ") --\n"); return processFunctionEnd(instOperands); } if (opcode != spirv::Opcode::OpLabel) { @@ -838,8 +870,8 @@ LogicalResult Deserializer::processFunction(ArrayRef operands) { return failure(); } - LLVM_DEBUG(llvm::dbgs() << "[fn] completed function '" << fnName << "' (type=" - << fnType << ", id=" << operands[1] << ")\n"); + LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = " + << fnType << ", id = " << operands[1] << ") --\n"); return processFunctionEnd(instOperands); } @@ -849,8 +881,9 @@ LogicalResult Deserializer::processFunctionEnd(ArrayRef operands) { return emitError(unknownLoc, "unexpected operands for OpFunctionEnd"); } + // Wire up block arguments from OpPhi instructions. // Put all structured control flow in spv.selection/spv.loop ops. - if (failed(structurizeControlFlow())) { + if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) { return failure(); } @@ -1438,7 +1471,7 @@ LogicalResult Deserializer::processConstantNull(ArrayRef operands) { Block *Deserializer::getOrCreateBlock(uint32_t id) { if (auto *block = getBlock(id)) { - LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id=" << id + LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id << " @ " << block << "\n"); return block; } @@ -1447,7 +1480,7 @@ Block *Deserializer::getOrCreateBlock(uint32_t id) { // or spv.loop or function). Create it into the function for now and sort // out the proper place later. auto *block = curFunction->addBlock(); - LLVM_DEBUG(llvm::dbgs() << "[block] created block for id=" << id << " @ " + LLVM_DEBUG(llvm::dbgs() << "[block] created block for id = " << id << " @ " << block << "\n"); return blockMap[id] = block; } @@ -1509,7 +1542,7 @@ LogicalResult Deserializer::processLabel(ArrayRef operands) { auto labelID = operands[0]; // We may have forward declared this block. auto *block = getOrCreateBlock(labelID); - LLVM_DEBUG(llvm::dbgs() << "[block] populating block @ " << block << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[block] populating block " << block << "\n"); // If we have seen this block, make sure it was just a forward declaration. assert(block->empty() && "re-deserialize the same block!"); @@ -1574,6 +1607,37 @@ LogicalResult Deserializer::processLoopMerge(ArrayRef operands) { return success(); } +LogicalResult Deserializer::processPhi(ArrayRef operands) { + if (!curBlock) { + return emitError(unknownLoc, "OpPhi must appear in a block"); + } + + if (operands.size() < 4) { + return emitError(unknownLoc, "OpPhi must specify result type, result , " + "and variable-parent pairs"); + } + + // Create a block argument for this OpPhi instruction. + Type blockArgType = getType(operands[0]); + BlockArgument *blockArg = curBlock->addArgument(blockArgType); + valueMap[operands[1]] = blockArg; + LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg + << " id = " << operands[1] << " of type " + << blockArgType << '\n'); + + // For each (value, predecessor) pair, insert the value to the predecessor's + // blockPhiInfo entry so later we can fix the block argument there. + for (unsigned i = 2, e = operands.size(); i < e; i += 2) { + uint32_t value = operands[i]; + Block *predecessor = getOrCreateBlock(operands[i + 1]); + blockPhiInfo[predecessor].push_back(value); + LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor + << " with arg id = " << value << '\n'); + } + + return success(); +} + namespace { /// A class for putting all blocks in a structured selection/loop in a /// spv.selection/spv.loop op. @@ -1711,6 +1775,14 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { mapper.map(block, newBlock); LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock << " from block " << block << "\n"); + if (!isFnEntryBlock(block)) { + for (BlockArgument *blockArg : block->getArguments()) { + auto *newArg = newBlock->addArgument(blockArg->getType()); + mapper.map(blockArg, newArg); + LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg + << " to " << newArg); + } + } for (auto &op : *block) newBlock->push_back(op.clone(mapper)); @@ -1758,7 +1830,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { // If the function's header block is also part of the structured control // flow, we cannot just simply erase it because it may contain arguments // matching the function signature and used by the cloned blocks. - if (block->isEntryBlock() && isa(block->getParentOp())) { + if (isFnEntryBlock(block)) { LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block << " to only contain a spv.Branch op\n"); // Still keep the function entry block for the potential block arguments, @@ -1775,29 +1847,77 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { return success(); } +LogicalResult Deserializer::wireUpBlockArgument() { + LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n"); + + OpBuilder::InsertionGuard guard(opBuilder); + + for (const auto &info : blockPhiInfo) { + Block *block = info.first; + const BlockPhiInfo &phiInfo = info.second; + LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n"); + LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << '\n'); + + // Set insertion point to before this block's terminator early because we + // may materialize ops via getValue() call. + auto *op = block->getTerminator(); + opBuilder.setInsertionPoint(op); + + SmallVector blockArgs; + blockArgs.reserve(phiInfo.size()); + for (uint32_t valueId : phiInfo) { + if (Value *value = getValue(valueId)) { + blockArgs.push_back(value); + LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value + << " id = " << valueId << '\n'); + } else { + return emitError(unknownLoc, "OpPhi references undefined value!"); + } + } + + if (auto branchOp = dyn_cast(op)) { + // Replace the previous branch op with a new one with block arguments. + opBuilder.create(branchOp.getLoc(), branchOp.getTarget(), + blockArgs); + branchOp.erase(); + } else { + return emitError(unknownLoc, "unimplemented terminator for Phi creation"); + } + + LLVM_DEBUG(llvm::dbgs() << "[phi] after creating block argument:\n"); + LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << '\n'); + } + blockPhiInfo.clear(); + + LLVM_DEBUG(llvm::dbgs() << "[phi] completed wiring up block arguments\n"); + return success(); +} + LogicalResult Deserializer::structurizeControlFlow() { LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n"); - while (!blockMergeInfo.empty()) { - auto *headerBlock = blockMergeInfo.begin()->first; - LLVM_DEBUG(llvm::dbgs() << "[cf] header block @ " << headerBlock << "\n"); + for (const auto &info : blockMergeInfo) { + auto *headerBlock = info.first; + LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << "\n"); - const auto &mergeInfo = blockMergeInfo.begin()->second; + const auto &mergeInfo = info.second; auto *mergeBlock = mergeInfo.mergeBlock; auto *continueBlock = mergeInfo.continueBlock; assert(mergeBlock && "merge block cannot be nullptr"); - LLVM_DEBUG(llvm::dbgs() << "[cf] merge block @ " << mergeBlock << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << "\n"); if (continueBlock) { LLVM_DEBUG(llvm::dbgs() - << "[cf] continue block @ " << continueBlock << "\n"); + << "[cf] continue block " << continueBlock << "\n"); } if (failed(ControlFlowStructurizer::structurize(unknownLoc, headerBlock, mergeBlock, continueBlock))) return failure(); - - blockMergeInfo.erase(headerBlock); } + blockMergeInfo.clear(); LLVM_DEBUG(llvm::dbgs() << "[cf] completed structurizing control flow\n"); return success(); @@ -1949,6 +2069,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return processSelectionMerge(operands); case spirv::Opcode::OpLoopMerge: return processLoopMerge(operands); + case spirv::Opcode::OpPhi: + return processPhi(operands); case spirv::Opcode::OpUndef: return processUndef(operands); default: diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 241be2a42975..afb9e0b81b7a 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -32,8 +32,11 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#define DEBUG_TYPE "spirv-serialization" + using namespace mlir; /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into @@ -49,6 +52,77 @@ LogicalResult encodeInstructionInto(SmallVectorImpl &binary, return success(); } +namespace { +/// A pre-order depth-first vistor for processing basic blocks. +/// +/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order +/// of blocks in a function must satisfy the rule that blocks appear before all +/// blocks they dominate." This can be achieved by a pre-order CFG traversal +/// algorithm. To make the serialization output more logical and readable to +/// human, we perform depth-first CFG traversal and delay the serialization of +/// the merge block and the continue block, if exists, until after all other +/// blocks have been processed. +/// +/// This visitor is special tailored for SPIR-V functions, spv.selection or +/// spv.loop block serialization to satisfy SPIR-V validation rules. It should +/// not be used as a general depth-first block visitor. +class PrettyBlockOrderVisitor { +public: + using BlockHandlerType = llvm::function_ref; + + /// Visits the basic blocks starting from the given `headerBlock`'s successors + /// in pre-order depth-first manner and calls `blockHandler` on each block. + /// Skips handling blocks in the `skipBlocks` list. If `headerBlock` is also + /// in `skipBlocks` list, still handles all its successors. + static LogicalResult visit(Block *headerBlock, BlockHandlerType blockHandler, + ArrayRef skipBlocks = {}) { + return PrettyBlockOrderVisitor(blockHandler, skipBlocks) + .visitHeaderBlock(headerBlock); + } + +private: + PrettyBlockOrderVisitor(BlockHandlerType blockHandler, + ArrayRef skipBlocks) + : blockHandler(blockHandler), + doneBlocks(skipBlocks.begin(), skipBlocks.end()) {} + + LogicalResult visitHeaderBlock(Block *header) { + // Skip processing the header block if requested. + if (!llvm::is_contained(doneBlocks, header)) { + if (failed(blockHandler(header))) + return failure(); + doneBlocks.insert(header); + } + + for (auto *successor : header->getSuccessors()) { + if (failed(visitNormalBlock(successor))) + return failure(); + } + + return success(); + } + + LogicalResult visitNormalBlock(Block *block) { + if (doneBlocks.count(block)) + return success(); + + if (failed(blockHandler(block))) + return failure(); + doneBlocks.insert(block); + + for (auto *successor : block->getSuccessors()) { + if (failed(visitNormalBlock(successor))) + return failure(); + } + + return success(); + } + + BlockHandlerType blockHandler; + SmallPtrSet doneBlocks; +}; +} // namespace + namespace { /// A SPIR-V module serializer. @@ -75,6 +149,9 @@ public: /// Collects the final SPIR-V `binary`. void collect(SmallVectorImpl &binary); + /// (For debugging) prints each value and its corresponding result . + void printValueIDMap(raw_ostream &os); + private: // Note that there are two main categories of methods in this class: // * process*() methods are meant to fully serialize a SPIR-V module entity @@ -244,19 +321,25 @@ private: // Control flow //===--------------------------------------------------------------------===// + /// Returns the result for the given block. uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); } - uint32_t assignBlockID(Block *block); + /// Returns the result for the given block. If no has been assigned, + /// assigns the next available + uint32_t getOrCreateBlockID(Block *block); - // Processes the given `block` and emits SPIR-V instructions for all ops - // inside. Does not emit OpLabel for this block if `omitLabel` is true. - // `actionBeforeTerminator` is a callback that will be invoked before handling - // the terminator op. It can be used to inject the Op*Merge instruction if - // this is a SPIR-V selection/loop header block. + /// Processes the given `block` and emits SPIR-V instructions for all ops + /// inside. Does not emit OpLabel for this block if `omitLabel` is true. + /// `actionBeforeTerminator` is a callback that will be invoked before + /// handling the terminator op. It can be used to inject the Op*Merge + /// instruction if this is a SPIR-V selection/loop header block. LogicalResult processBlock(Block *block, bool omitLabel = false, llvm::function_ref actionBeforeTerminator = nullptr); + /// Emits OpPhi instructions for the given block if it has block arguments. + LogicalResult emitPhiForBlockArguments(Block *block); + LogicalResult processSelectionOp(spirv::SelectionOp selectionOp); LogicalResult processLoopOp(spirv::LoopOp loopOp); @@ -356,6 +439,46 @@ private: /// Map from extended instruction set name to s. llvm::StringMap extendedInstSetIDMap; + + /// Map from values used in OpPhi instructions to their offset in the + /// `functions` section. + /// + /// When processing a block with arguments, we need to emit OpPhi + /// instructions to record the predecessor block s and the values they + /// send to the block in question. But it's not guaranteed all values are + /// visited and thus assigned result s. So we need this list to capture + /// the offsets into `functions` where a value is used so that we can fix it + /// up later after processing all the blocks in a function. + /// + /// More concretely, say if we are visiting the following blocks: + /// + /// ```mlir + /// ^phi(%arg0: i32): + /// ... + /// ^parent1: + /// ... + /// spv.Branch ^phi(%val0: i32) + /// ^parent2: + /// ... + /// spv.Branch ^phi(%val1: i32) + /// ``` + /// + /// When we are serializing the `^phi` block, we need to emit at the beginning + /// of the block OpPhi instructions which has the following parameters: + /// + /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1 + /// id-for-%val1 id-for-^parent2 + /// + /// But we don't know the for %val0 and %val1 yet. One way is to visit + /// all the blocks twice and use the first visit to assign an to each + /// value. But it's paying the overheads just for OpPhi emission. Instead, + /// we still visit the blocks once for emssion. When we emit the OpPhi + /// instructions, we use 0 as a placeholder for the s for %val0 and %val1. + /// At the same time, we record their offsets in the emitted binary (which is + /// placed inside `functions`) here. And then after emitting all blocks, we + /// replace the dummy 0 with the real result by overwriting + /// `functions[offset]`. + DenseMap> deferredPhiValues; }; } // namespace @@ -363,6 +486,8 @@ Serializer::Serializer(spirv::ModuleOp module) : module(module), mlirBuilder(module.getContext()) {} LogicalResult Serializer::serialize() { + LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); + if (failed(module.verify())) return failure(); @@ -378,6 +503,8 @@ LogicalResult Serializer::serialize() { return failure(); } } + + LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); return success(); } @@ -403,6 +530,24 @@ void Serializer::collect(SmallVectorImpl &binary) { binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); binary.append(functions.begin(), functions.end()); } + +void Serializer::printValueIDMap(raw_ostream &os) { + os << "\n= Value Map =\n\n"; + for (auto valueIDPair : valueIDMap) { + Value *val = valueIDPair.first; + os << " " << val << " " + << "id = " << valueIDPair.second << ' '; + if (auto *op = val->getDefiningOp()) { + os << "from op '" << op->getName() << "'"; + } else if (auto *arg = dyn_cast(val)) { + Block *block = arg->getOwner(); + os << "from argument of block " << block << ' '; + os << " in op '" << block->getParentOp()->getName() << "'"; + } + os << '\n'; + } +} + //===----------------------------------------------------------------------===// // Module structure //===----------------------------------------------------------------------===// @@ -564,6 +709,8 @@ Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex, } // namespace LogicalResult Serializer::processFuncOp(FuncOp op) { + LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); + uint32_t fnTypeID = 0; // Generate type of the function. processType(op.getLoc(), op.getType(), fnTypeID); @@ -610,11 +757,24 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { return op.emitError("external function is unhandled"); } - for (auto &block : op) { - if (failed(processBlock(&block))) - return failure(); - } + if (failed(PrettyBlockOrderVisitor::visit( + &op.front(), [&](Block *block) { return processBlock(block); }))) + return failure(); + // There might be OpPhi instructions who have value references needing to fix. + for (auto deferredValue : deferredPhiValues) { + Value *value = deferredValue.first; + uint32_t id = getValueID(value); + LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value + << " to id = " << id << '\n'); + assert(id && "OpPhi references undefined value!"); + for (size_t offset : deferredValue.second) + functions[offset] = id; + } + deferredPhiValues.clear(); + + LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() + << "' --\n"); // Insert OpFunctionEnd. return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {}); } @@ -842,7 +1002,7 @@ Serializer::prepareFunctionType(Location loc, FunctionType type, SmallVectorImpl &operands) { typeEnum = spirv::Opcode::OpTypeFunction; assert(type.getNumResults() <= 1 && - "Serialization supports only a single return value"); + "serialization supports only a single return value"); uint32_t resultID = 0; if (failed(processType( loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), @@ -1221,24 +1381,31 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, // Control flow //===----------------------------------------------------------------------===// -uint32_t Serializer::assignBlockID(Block *block) { - assert(blockIDMap.lookup(block) == 0 && "block already has "); +uint32_t Serializer::getOrCreateBlockID(Block *block) { + if (uint32_t id = getBlockID(block)) + return id; return blockIDMap[block] = getNextID(); } LogicalResult Serializer::processBlock(Block *block, bool omitLabel, llvm::function_ref actionBeforeTerminator) { + LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); + LLVM_DEBUG(block->print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << '\n'); if (!omitLabel) { - auto blockID = getBlockID(block); - if (blockID == 0) { - blockID = assignBlockID(block); - } + uint32_t blockID = getOrCreateBlockID(block); + LLVM_DEBUG(llvm::dbgs() + << "[block] " << block << " (id = " << blockID << ")\n"); // Emit OpLabel for this block. encodeInstructionInto(functions, spirv::Opcode::OpLabel, {blockID}); } + // Emit OpPhi instructions for block arguments, if any. + if (failed(emitPhiForBlockArguments(block))) + return failure(); + // Process each op in this block except the terminator. for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { if (failed(processOperation(&op))) @@ -1254,78 +1421,78 @@ Serializer::processBlock(Block *block, bool omitLabel, return success(); } -namespace { -/// A pre-order depth-first vistor for processing basic blocks in a spv.loop op. -/// -/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order -/// of blocks in a function must satisfy the rule that blocks appear before all -/// blocks they dominate." This can be achieved by a pre-order CFG traversal -/// algorithm. To make the serialization output more logical and readable to -/// human, we perform depth-first CFG traversal and delay the serialization of -/// the merge block (and the continue block) until after all other blocks have -/// been processed. -/// -/// This visitor is special tailored for spv.selection or spv.loop block -/// serialization to satisfy SPIR-V validation rules. It should not be used -/// as a general depth-first block visitor. -class ControlFlowBlockVisitor { -public: - using BlockHandlerType = llvm::function_ref; - - /// Visits the basic blocks starting from the given `headerBlock`'s successors - /// in pre-order depth-first manner and calls `blockHandler` on each block. - /// Skips handling the `headerBlock` and blocks in the `skipBlocks` list. - static LogicalResult visit(Block *headerBlock, BlockHandlerType blockHandler, - ArrayRef skipBlocks) { - return ControlFlowBlockVisitor(blockHandler, skipBlocks) - .visitHeaderBlock(headerBlock); - } - -private: - ControlFlowBlockVisitor(BlockHandlerType blockHandler, - ArrayRef skipBlocks) - : blockHandler(blockHandler), - doneBlocks(skipBlocks.begin(), skipBlocks.end()) {} - - LogicalResult visitHeaderBlock(Block *header) { - // Skip processing the header block. - doneBlocks.insert(header); - - for (auto *successor : header->getSuccessors()) { - if (failed(visitNormalBlock(successor))) - return failure(); - } - +LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { + // Nothing to do if this block has no arguments or it's the entry block, which + // always has the same arguments as the function signature. + if (block->args_empty() || block->isEntryBlock()) return success(); + + // If the block has arguments, we need to create SPIR-V OpPhi instructions. + // A SPIR-V OpPhi instruction is of the syntax: + // OpPhi | result type | result | (value , parent block ) pair + // So we need to collect all predecessor blocks and the arguments they send + // to this block. + SmallVector, 4> predecessors; + for (Block *predecessor : block->getPredecessors()) { + auto *op = predecessor->getTerminator(); + if (auto branchOp = dyn_cast(op)) { + predecessors.emplace_back(predecessor, branchOp.operand_begin()); + } else { + return op->emitError("unimplemented terminator for Phi creation"); + } } - LogicalResult visitNormalBlock(Block *block) { - if (doneBlocks.count(block)) - return success(); + // Then create OpPhi instruction for each of the block argument. + for (auto argIndex : llvm::seq(0, block->getNumArguments())) { + BlockArgument *arg = block->getArgument(argIndex); - if (failed(blockHandler(block))) + // Get the type and result for this OpPhi instruction. + uint32_t phiTypeID = 0; + if (failed(processType(arg->getLoc(), arg->getType(), phiTypeID))) return failure(); - doneBlocks.insert(block); + uint32_t phiID = getNextID(); - for (auto *successor : block->getSuccessors()) { - if (failed(visitNormalBlock(successor))) - return failure(); + LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' + << arg << " (id = " << phiID << ")\n"); + + SmallVector phiArgs; + phiArgs.push_back(phiTypeID); + phiArgs.push_back(phiID); + + for (auto predIndex : llvm::seq(0, predecessors.size())) { + Value *value = *(predecessors[predIndex].second + argIndex); + uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); + LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId + << ") value " << value << ' '); + // Each pair is a value ... + uint32_t valueId = getValueID(value); + if (valueId == 0) { + // The op generating this value hasn't been visited yet so we don't have + // an assigned yet. Record this to fix up later. + LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); + deferredPhiValues[value].push_back(functions.size() + 1 + + phiArgs.size()); + } else { + LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); + } + phiArgs.push_back(valueId); + // ... and a parent block . + phiArgs.push_back(predBlockId); } - return success(); + encodeInstructionInto(functions, spirv::Opcode::OpPhi, phiArgs); + valueIDMap[arg] = phiID; } - BlockHandlerType blockHandler; - SmallPtrSet doneBlocks; -}; -} // namespace + return success(); +} LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { // Assign s to all blocks so that branches inside the SelectionOp can // resolve properly. auto &body = selectionOp.body(); for (Block &block : body) - assignBlockID(&block); + getOrCreateBlockID(&block); auto *headerBlock = selectionOp.getHeaderBlock(); auto *mergeBlock = selectionOp.getMergeBlock(); @@ -1353,8 +1520,8 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { // block. The selection header block and merge block are skipped by this // visitor. auto handleBlock = [&](Block *block) { return processBlock(block); }; - if (failed(ControlFlowBlockVisitor::visit(headerBlock, handleBlock, - {mergeBlock}))) + if (failed(PrettyBlockOrderVisitor::visit(headerBlock, handleBlock, + {headerBlock, mergeBlock}))) return failure(); // There is nothing to do for the merge block in the selection, which just @@ -1371,7 +1538,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { auto &body = loopOp.body(); for (Block &block : llvm::make_range(std::next(body.begin(), 1), body.end())) { - assignBlockID(&block); + getOrCreateBlockID(&block); } auto *headerBlock = loopOp.getHeaderBlock(); auto *continueBlock = loopOp.getContinueBlock(); @@ -1403,8 +1570,8 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { // block. The loop header block, loop continue block, and loop merge block are // skipped by this visitor and handled later in this function. auto handleBlock = [&](Block *block) { return processBlock(block); }; - if (failed(ControlFlowBlockVisitor::visit(headerBlock, handleBlock, - {continueBlock, mergeBlock}))) + if (failed(PrettyBlockOrderVisitor::visit( + headerBlock, handleBlock, {headerBlock, continueBlock, mergeBlock}))) return failure(); // We have handled all other blocks. Now get to the loop continue block. @@ -1421,8 +1588,8 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { LogicalResult Serializer::processBranchConditionalOp( spirv::BranchConditionalOp condBranchOp) { auto conditionID = getValueID(condBranchOp.condition()); - auto trueLabelID = getBlockID(condBranchOp.getTrueBlock()); - auto falseLabelID = getBlockID(condBranchOp.getFalseBlock()); + auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock()); + auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock()); SmallVector arguments{conditionID, trueLabelID, falseLabelID}; if (auto weights = condBranchOp.branch_weights()) { @@ -1436,7 +1603,7 @@ LogicalResult Serializer::processBranchConditionalOp( LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { return encodeInstructionInto(functions, spirv::Opcode::OpBranch, - {getBlockID(branchOp.getTarget())}); + {getOrCreateBlockID(branchOp.getTarget())}); } //===----------------------------------------------------------------------===// @@ -1500,6 +1667,8 @@ Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { } LogicalResult Serializer::processOperation(Operation *op) { + LLVM_DEBUG(llvm::dbgs() << "[op] '" << op->getName() << "'\n"); + // First dispatch the ops that do not directly mirror an instruction from // the SPIR-V spec. if (auto addressOfOp = dyn_cast(op)) { @@ -1702,6 +1871,8 @@ LogicalResult spirv::serialize(spirv::ModuleOp module, if (failed(serializer.serialize())) return failure(); + LLVM_DEBUG(serializer.printValueIDMap(llvm::dbgs())); + serializer.collect(binary); return success(); } diff --git a/mlir/test/Dialect/SPIRV/Serialization/phi.mlir b/mlir/test/Dialect/SPIRV/Serialization/phi.mlir new file mode 100644 index 000000000000..58e64e35f17b --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/phi.mlir @@ -0,0 +1,165 @@ +// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s + +// Test branch with one block argument + +spv.module "Logical" "GLSL450" { + func @foo() -> () { +// CHECK: %[[CST:.*]] = spv.constant 0 + %zero = spv.constant 0 : i32 +// CHECK-NEXT: spv.Branch ^bb1(%[[CST]] : i32) + spv.Branch ^bb1(%zero : i32) +// CHECK-NEXT: ^bb1(%{{.*}}: i32): + ^bb1(%arg0: i32): + spv.Return + } + + func @main() -> () { + spv.Return + } + spv.EntryPoint "GLCompute" @main +} attributes { + capabilities = ["Shader"] +} + +// ----- + +// Test branch with multiple block arguments + +spv.module "Logical" "GLSL450" { + func @foo() -> () { +// CHECK: %[[ZERO:.*]] = spv.constant 0 + %zero = spv.constant 0 : i32 +// CHECK-NEXT: %[[ONE:.*]] = spv.constant 1 + %one = spv.constant 1.0 : f32 +// CHECK-NEXT: spv.Branch ^bb1(%[[ZERO]], %[[ONE]] : i32, f32) + spv.Branch ^bb1(%zero, %one : i32, f32) + +// CHECK-NEXT: ^bb1(%{{.*}}: i32, %{{.*}}: f32): // pred: ^bb0 + ^bb1(%arg0: i32, %arg1: f32): + spv.Return + } + + func @main() -> () { + spv.Return + } + spv.EntryPoint "GLCompute" @main +} attributes { + capabilities = ["Shader"] +} + +// ----- + +// Test using block arguments within branch + +spv.module "Logical" "GLSL450" { + func @foo() -> () { +// CHECK: %[[CST0:.*]] = spv.constant 0 + %zero = spv.constant 0 : i32 +// CHECK-NEXT: spv.Branch ^bb1(%[[CST0]] : i32) + spv.Branch ^bb1(%zero : i32) + +// CHECK-NEXT: ^bb1(%[[ARG:.*]]: i32): + ^bb1(%arg0: i32): +// CHECK-NEXT: %[[ADD:.*]] = spv.IAdd %[[ARG]], %[[ARG]] : i32 + %0 = spv.IAdd %arg0, %arg0 : i32 +// CHECK-NEXT: %[[CST1:.*]] = spv.constant 0 +// CHECK-NEXT: spv.Branch ^bb2(%[[CST1]], %[[ADD]] : i32, i32) + spv.Branch ^bb2(%zero, %0 : i32, i32) + +// CHECK-NEXT: ^bb2(%{{.*}}: i32, %{{.*}}: i32): + ^bb2(%arg1: i32, %arg2: i32): + spv.Return + } + + func @main() -> () { + spv.Return + } + spv.EntryPoint "GLCompute" @main +} attributes { + capabilities = ["Shader"] +} + +// ----- + +// Test block not following domination order + +spv.module "Logical" "GLSL450" { + func @foo() -> () { +// CHECK: spv.Branch ^bb1 + spv.Branch ^bb1 + +// CHECK-NEXT: ^bb1: +// CHECK-NEXT: %[[ZERO:.*]] = spv.constant 0 +// CHECK-NEXT: %[[ONE:.*]] = spv.constant 1 +// CHECK-NEXT: spv.Branch ^bb2(%[[ZERO]], %[[ONE]] : i32, f32) + +// CHECK-NEXT: ^bb2(%{{.*}}: i32, %{{.*}}: f32): + ^bb2(%arg0: i32, %arg1: f32): +// CHECK-NEXT: spv.Return + spv.Return + + // This block is reordered to follow domination order. + ^bb1: + %zero = spv.constant 0 : i32 + %one = spv.constant 1.0 : f32 + spv.Branch ^bb2(%zero, %one : i32, f32) + } + + func @main() -> () { + spv.Return + } + spv.EntryPoint "GLCompute" @main +} attributes { + capabilities = ["Shader"] +} + +// ----- + +// Test multiple predecessors + +spv.module "Logical" "GLSL450" { + func @foo() -> () { + %var = spv.Variable : !spv.ptr + +// CHECK: spv.selection + spv.selection { + %true = spv.constant true +// CHECK: spv.BranchConditional %{{.*}}, ^bb1, ^bb2 + spv.BranchConditional %true, ^true, ^false + +// CHECK-NEXT: ^bb1: + ^true: +// CHECK-NEXT: %[[ZERO:.*]] = spv.constant 0 + %zero = spv.constant 0 : i32 +// CHECK-NEXT: spv.Branch ^bb3(%[[ZERO]] : i32) + spv.Branch ^phi(%zero: i32) + +// CHECK-NEXT: ^bb2: + ^false: +// CHECK-NEXT: %[[ONE:.*]] = spv.constant 1 + %one = spv.constant 1 : i32 +// CHECK-NEXT: spv.Branch ^bb3(%[[ONE]] : i32) + spv.Branch ^phi(%one: i32) + +// CHECK-NEXT: ^bb3(%[[ARG:.*]]: i32): + ^phi(%arg: i32): +// CHECK-NEXT: spv.Store "Function" %{{.*}}, %[[ARG]] : i32 + spv.Store "Function" %var, %arg : i32 +// CHECK-NEXT: spv.Return + spv.Return + +// CHECK-NEXT: ^bb4: + ^merge: +// CHECK-NEXT: spv._merge + spv._merge + } + spv.Return + } + + func @main() -> () { + spv.Return + } + spv.EntryPoint "GLCompute" @main +} attributes { + capabilities = ["Shader"] +}