diff --git a/mlir/g3doc/Dialects/SPIR-V.md b/mlir/g3doc/Dialects/SPIR-V.md index e149540546a5..0a6bc70c1a76 100644 --- a/mlir/g3doc/Dialects/SPIR-V.md +++ b/mlir/g3doc/Dialects/SPIR-V.md @@ -18,44 +18,66 @@ but not necessarily compiler transformations. The purpose of the SPIR-V dialect is to serve as the "proxy" of the binary format and to facilitate transformations. Therefore, it should -* Be trivial to serialize into the SPIR-V binary format; -* Stay as the same semantic level and try to be a mechanical 1:1 mapping; -* But deviate representationally if possible with MLIR mechanisms. +* Stay as the same semantic level and try to be a mechanical 1:1 mapping; +* But deviate representationally if possible with MLIR mechanisms. +* Be straightforward to serialize into and deserialize drom the SPIR-V binary + format. ## Conventions The SPIR-V dialect has the following conventions: -* The prefix for all SPIR-V types and operations are `spv.`. -* Ops that directly correspond to instructions in the binary format have - `CamelCase` names, for example, `spv.FMul`; -* Otherwise they have `snake_case` names. These ops are mostly for defining - the SPIR-V structure, inclduing module, function, and module-level ops. - For example, `spv.module`, `spv.constant`. +* The prefix for all SPIR-V types and operations are `spv.`. +* Ops that directly mirror instructions in the binary format have `CamelCase` + names that are the same as the instruction opnames (without the `Op` + prefix). For example, `spv.FMul` is a direct mirror of `OpFMul`. They will + be serialized into and deserialized from one instruction. +* Ops with `snake_case` names are those that have different representation + from corresponding instructions (or concepts) in the binary format. These + ops are mostly for defining the SPIR-V structure. For example, `spv.module` + and `spv.constant`. They may correspond to zero or more instructions during + (de)serialization. +* Ops with `_snake_case` names are those that have no corresponding + instructions (or concepts) in the binary format. They are introduced to + satisfy MLIR structural requirements. For example, `spv._module_end` and + `spv._merge`. They maps to no instructions during (de)serialization. ## Module A SPIR-V module is defined via the `spv.module` op, which has one region that contains one block. Model-level instructions, including function definitions, -are all placed inside the block. Functions are defined using the standard `func` +are all placed inside the block. Functions are defined using the builtin `func` op. Compared to the binary format, we adjust how certain module-level SPIR-V instructions are represented in the SPIR-V dialect. Notably, -* Requirements for capabilities, extensions, extended instruction sets, - addressing model, and memory model is conveyed using `spv.module` attributes. - This is considered better because these information are for the - exexcution environment. It's eaiser to probe them if on the module op - itself. -* Annotations/decoration instrutions are "folded" into the instructions they - decorate and represented as attributes on those ops. This elimiates potential - forward references of SSA values, improves IR readability, and makes - querying the annotations more direct. -* Various constant instructions are represented by the same `spv.constant` - op. Those instructions are just for constants of different types; using one - op to represent them reduces IR verbosity and makes transformations less - tedious. +* Requirements for capabilities, extensions, extended instruction sets, + addressing model, and memory model is conveyed using `spv.module` + attributes. This is considered better because these information are for the + exexcution environment. It's eaiser to probe them if on the module op + itself. +* Annotations/decoration instrutions are "folded" into the instructions they + decorate and represented as attributes on those ops. This elimiates + potential forward references of SSA values, improves IR readability, and + makes querying the annotations more direct. +* Types are represented using MLIR standard types and SPIR-V dialect specific + types. There are no type declaration ops in the SPIR-V dialect. +* Various normal constant instructions are represented by the same + `spv.constant` op. Those instructions are just for constants of different + types; using one op to represent them reduces IR verbosity and makes + transformations less tedious. +* Normal constants are not placed in `spv.module`'s region; they are localized + into functions. This is to make functions in the SPIR-V dialect to be + isolated and explicit capturing. +* Global variables are defined with the `spv.globalVariable` op. They do not + generate SSA values. Instead they have symbols and should be referenced via + symbols. To use a global variables in a function block, `spv._address_of` is + needed to turn the symbol into a SSA value. +* Specialization constants are defined with the `spv.specConstant` op. Similar + to global variables, they do not generate SSA values and have symbols for + reference, too. `spv._reference_of` is needed to turn the symbol into a SSA + value for use in a function block. ## Types @@ -170,6 +192,121 @@ For Example, !spv.struct ``` +## Function + +A SPIR-V function is defined using the builtin `func` op. `spv.module` verifies +that the functions inside it comply with SPIR-V requirements: at most one +result, no nested functions, and so on. + +## Control Flow + +SPIR-V binary format uses merge instructions (`OpSelectionMerge` and +`OpLoopMerge`) to declare structured control flow. They explicitly declare a +header block before the control flow diverges and a merge block where control +flow subsequently converges. These blocks delimit constructs that must nest, and +can only be entered and exited in structured ways. + +In the SPIR-V dialect, we use regions to mark the boundary of a structured +control flow construct. With this approach, it's easier to discover all blocks +belonging to a structured control flow construct. It is also more idiomatic to +MLIR system. + +We introduce a a `spv.loop` op for structured loops. The merge targets are the +next ops following them. Inside their regions, a special terminator, +`spv._merge` is introduced for branching to the merge target. + +### Loop + +`spv.loop` defines a loop construct. It contains one region. The `spv.loop` +region should contain at least four blocks: one entry block, one loop header +block, one loop continue block, one merge block. + +* The entry block should be the first block and it should jump to the loop + header block, which is the second block. +* The merge block should be the last block. The merge block should only + contain a `spv._merge` op. Any block except the entry block can branch to + the merge block for early exit. +* The continue block should be the second to last block and it should have a + branch to the loop header block. +* The loop continue block should be the only block, except the entry block, + branching to the loop header block. + +``` + +-------------+ + | entry block | (one outgoing branch) + +-------------+ + | + v + +-------------+ (two incoming branches) + | loop header | <-----+ (may have one or two outgoing branches) + +-------------+ | + | + ... | + \ | / | + v | + +---------------+ | (may have multiple incoming branches) + | loop continue | -----+ (may have one or two outgoing branches) + +---------------+ + + ... + \ | / + v + +-------------+ (may have mulitple incoming branches) + | merge block | + +-------------+ +``` + +The reason to have another entry block instead of directly using the loop header +block as the entry block is to satisfy region's requirement: entry block of +region may not have predecessors. We have a merge block so that branch ops can +reference it as successors. The loop continue block here corresponds to +"continue construct" using SPIR-V spec's term; it does not mean the "continue +block" as defined in the SPIR-V spec, which is "a block containing a branch to +an OpLoopMerge instruction’s Continue Target." + +For example, for the given function + +```c++ +void loop(int count) { + for (int i = 0; i < count; ++i) { + // ... + } +} +``` + +It will be represented as + +```mlir +func @loop(%count : i32) -> () { + %zero = spv.constant 0: i32 + %one = spv.constant 1: i32 + %var = spv.Variable init(%zero) : !spv.ptr + + spv.loop { + spv.Branch ^header + + ^header: + %val0 = spv.Load "Function" %var : i32 + %cmp = spv.SLessThan %val0, %count : i32 + spv.BranchConditional %cmp, ^body, ^merge + + ^body: + // ... + spv.Branch ^continue + + ^continue: + %val1 = spv.Load "Function" %var : i32 + %add = spv.IAdd %val1, %one : i32 + spv.Store "Function" %var, %add : i32 + spv.Branch ^header + + ^merge: + spv._merge + } + return +} +``` + ## Serialization The serialization library provides two entry points, `mlir::spirv::serialize()` diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 7dea586f9191..0accb05f0ac4 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -144,6 +144,7 @@ def SPV_OC_OpFOrdLessThanEqual : I32EnumAttrCase<"OpFOrdLessThanEqual", 188 def SPV_OC_OpFUnordLessThanEqual : I32EnumAttrCase<"OpFUnordLessThanEqual", 189>; def SPV_OC_OpFOrdGreaterThanEqual : I32EnumAttrCase<"OpFOrdGreaterThanEqual", 190>; def SPV_OC_OpFUnordGreaterThanEqual : I32EnumAttrCase<"OpFUnordGreaterThanEqual", 191>; +def SPV_OC_OpLoopMerge : I32EnumAttrCase<"OpLoopMerge", 246>; def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>; def SPV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>; def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>; @@ -173,9 +174,9 @@ def SPV_OpcodeAttr : SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, - SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpLabel, - SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, - SPV_OC_OpReturnValue + SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, + SPV_OC_OpLoopMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, + SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue ]> { let returnType = "::mlir::spirv::Opcode"; let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())"; @@ -924,6 +925,28 @@ def SPV_LinkageTypeAttr : let cppNamespace = "::mlir::spirv"; } +def SPV_LC_None : I32EnumAttrCase<"None", 0x0000>; +def SPV_LC_Unroll : I32EnumAttrCase<"Unroll", 0x0001>; +def SPV_LC_DontUnroll : I32EnumAttrCase<"DontUnroll", 0x0002>; +def SPV_LC_DependencyInfinite : I32EnumAttrCase<"DependencyInfinite", 0x0004>; +def SPV_LC_DependencyLength : I32EnumAttrCase<"DependencyLength", 0x0008>; +def SPV_LC_MinIterations : I32EnumAttrCase<"MinIterations", 0x0010>; +def SPV_LC_MaxIterations : I32EnumAttrCase<"MaxIterations", 0x0020>; +def SPV_LC_IterationMultiple : I32EnumAttrCase<"IterationMultiple", 0x0040>; +def SPV_LC_PeelCount : I32EnumAttrCase<"PeelCount", 0x0080>; +def SPV_LC_PartialCount : I32EnumAttrCase<"PartialCount", 0x0100>; + +def SPV_LoopControlAttr : + I32EnumAttr<"LoopControl", "valid SPIR-V LoopControl", [ + SPV_LC_None, SPV_LC_Unroll, SPV_LC_DontUnroll, SPV_LC_DependencyInfinite, + SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations, + SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount + ]> { + let returnType = "::mlir::spirv::LoopControl"; + let convertFromStorage = "static_cast<::mlir::spirv::LoopControl>($_self.getInt())"; + let cppNamespace = "::mlir::spirv"; +} + def SPV_MA_None : I32EnumAttrCase<"None", 0x0000>; def SPV_MA_Volatile : I32EnumAttrCase<"Volatile", 0x0001>; def SPV_MA_Aligned : I32EnumAttrCase<"Aligned", 0x0002>; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index 0927684732f0..ffefa143c7c9 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -137,6 +137,71 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> { // ----- +def SPV_LoopOp : SPV_Op<"loop"> { + let summary = "Define a structured loop."; + + let description = [{ + SPIR-V can explicitly declare structured control-flow constructs using merge + instructions. These explicitly declare a header block before the control + flow diverges and a merge block where control flow subsequently converges. + These blocks delimit constructs that must nest, and can only be entered + and exited in structured ways. See "2.11. Structured Control Flow" of the + SPIR-V spec for more details. + + Instead of having a `spv.LoopMerge` op to directly model loop merge + instruction for indicating the merge and continue target, we use regions + to delimit the boundary of the loop: the merge target is the next op + following the `spv.loop` op and the continue target is the block that + has a back-edge pointing to the entry block inside the `spv.loop`'s region. + This way it's easier to discover all blocks belonging to a construct and + it plays nicer with the MLIR system. + + The `spv.loop` region should contain at least four blocks: one entry block, + one loop header block, one loop continue block, one loop merge block. + The entry block should be the first block and it should jump to the loop + header block, which is the second block. The loop merge block should be the + last block. The merge block should only contain a `spv._merge` op. + The continue block should be the second to last block and it should have a + branch to the loop header block. The loop continue block should be the only + block, except the entry block, branching to the header block. + }]; + + let arguments = (ins + SPV_LoopControlAttr:$loop_control + ); + + let results = (outs); + + let regions = (region AnyRegion:$body); + + let hasOpcode = 0; +} + +// ----- + +def SPV_MergeOp : SPV_Op<"_merge", [HasParent<"LoopOp">, Terminator]> { + let summary = "A special terminator for merging a structured selection/loop."; + + let description = [{ + We use `spv.selection`/`spv.loop` for modelling structured selection/loop. + This op is a terminator used inside their regions to mean jumping to the + merge point, which is the next op following the `spv.selection` or + `spv.loop` op. This op does not have a corresponding instruction in the + SPIR-V binary format; it's solely for structural purpose. + }]; + + let arguments = (ins); + + let results = (outs); + + let parser = [{ return parseNoIOOp(parser, result); }]; + let printer = [{ printNoIOOp(getOperation(), p); }]; + + let hasOpcode = 0; +} + +// ----- + def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> { let summary = "Return with no value from a function with void return type."; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 3338c104a7ce..0873eb0c9a01 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1017,6 +1017,143 @@ static LogicalResult verify(spirv::LoadOp loadOp) { return verifyMemoryAccessAttribute(loadOp); } +//===----------------------------------------------------------------------===// +// spv.loop +//===----------------------------------------------------------------------===// + +static ParseResult parseLoopOp(OpAsmParser *parser, OperationState *state) { + // TODO(antiagainst): support loop control properly + Builder builder = parser->getBuilder(); + state->addAttribute("loop_control", + builder.getI32IntegerAttr( + static_cast(spirv::LoopControl::None))); + + return parser->parseRegion(*state->addRegion(), /*arguments=*/{}, + /*argTypes=*/{}); +} + +static void print(spirv::LoopOp loopOp, OpAsmPrinter *printer) { + auto *op = loopOp.getOperation(); + + *printer << spirv::LoopOp::getOperationName(); + printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); +} + +/// Returns true if the given `block` only contains one `spv._merge` op. +static inline bool isMergeBlock(Block &block) { + return std::next(block.begin()) == block.end() && + isa(block.front()); +} + +/// Returns true if the given `srcBlock` contains only one `spv.Branch` to the +/// given `dstBlock`. +static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) { + // Check that there is only one op in the `srcBlock`. + if (std::next(srcBlock.begin()) != srcBlock.end()) + return false; + + auto branchOp = dyn_cast(srcBlock.back()); + return branchOp && branchOp.getSuccessor(0) == &dstBlock; +} + +static LogicalResult verify(spirv::LoopOp loopOp) { + auto *op = loopOp.getOperation(); + + // We need to verify that the blocks follow the following layout: + // + // +-------------+ + // | entry block | + // +-------------+ + // | + // v + // +-------------+ + // | loop header | <-----+ + // +-------------+ | + // | + // ... | + // \ | / | + // v | + // +---------------+ | + // | loop continue | -----+ + // +---------------+ + // + // ... + // \ | / + // v + // +-------------+ + // | merge block | + // +-------------+ + + auto ®ion = op->getRegion(0); + // Allow empty region as a degenerated case, which can come from + // optimizations. + if (region.empty()) + return success(); + + // The last block is the merge block. + Block &merge = region.back(); + if (!isMergeBlock(merge)) + return loopOp.emitOpError( + "last block must be the merge block with only one 'spv._merge' op"); + + if (std::next(region.begin()) == region.end()) + return loopOp.emitOpError( + "must have an entry block branching to the loop header block"); + // The first block is the entry block. + Block &entry = region.front(); + + if (std::next(region.begin(), 2) == region.end()) + return loopOp.emitOpError( + "must have a loop header block branched from the entry block"); + // The second block is the loop header block. + Block &header = *std::next(region.begin(), 1); + + if (!hasOneBranchOpTo(entry, header)) + return loopOp.emitOpError( + "entry block must only have one 'spv.Branch' op to the second block"); + + if (std::next(region.begin(), 3) == region.end()) + return loopOp.emitOpError( + "requires a loop continue block branching to the loop header block"); + // The second to last block is the loop continue block. + Block &cont = *std::prev(region.end(), 2); + + // Make sure that we have a branch from the loop continue block to the loop + // header block. + if (llvm::none_of( + llvm::seq(0, cont.getNumSuccessors()), + [&](unsigned index) { return cont.getSuccessor(index) == &header; })) + return loopOp.emitOpError("second to last block must be the loop continue " + "block that branches to the loop header block"); + + // Make sure that no other blocks (except the entry and loop continue block) + // branches to the loop header block. + for (auto &block : llvm::make_range(std::next(region.begin(), 2), + std::prev(region.end(), 2))) { + for (auto i : llvm::seq(0, block.getNumSuccessors())) { + if (block.getSuccessor(i) == &header) { + return loopOp.emitOpError("can only have the entry and loop continue " + "block branching to the loop header block"); + } + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// spv._merge +//===----------------------------------------------------------------------===// + +static LogicalResult verify(spirv::MergeOp mergeOp) { + Block &parentLastBlock = mergeOp.getParentRegion()->back(); + if (mergeOp.getOperation() != parentLastBlock.getTerminator()) + return mergeOp.emitOpError( + "can only be used in the last block of 'spv.loop'"); + return success(); +} + //===----------------------------------------------------------------------===// // spv.module //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir index 11b8c9f3d4a9..8199d0ea6102 100644 --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -144,6 +144,190 @@ func @weights_cannot_both_be_zero() -> () { // ----- +//===----------------------------------------------------------------------===// +// spv.loop +//===----------------------------------------------------------------------===// + +// for (int i = 0; i < count; ++i) {} +func @loop(%count : i32) -> () { + %zero = spv.constant 0: i32 + %one = spv.constant 1: i32 + %var = spv.Variable init(%zero) : !spv.ptr + + // CHECK: spv.loop { + spv.loop { + // CHECK-NEXT: spv.Branch ^bb1 + spv.Branch ^header + + // CHECK-NEXT: ^bb1: + ^header: + %val0 = spv.Load "Function" %var : i32 + %cmp = spv.SLessThan %val0, %count : i32 + // CHECK: spv.BranchConditional %{{.*}}, ^bb2, ^bb4 + spv.BranchConditional %cmp, ^body, ^merge + + // CHECK-NEXT: ^bb2: + ^body: + // Do nothing + // CHECK-NEXT: spv.Branch ^bb3 + spv.Branch ^continue + + // CHECK-NEXT: ^bb3: + ^continue: + %val1 = spv.Load "Function" %var : i32 + %add = spv.IAdd %val1, %one : i32 + spv.Store "Function" %var, %add : i32 + // CHECK: spv.Branch ^bb1 + spv.Branch ^header + + // CHECK-NEXT: ^bb4: + ^merge: + spv._merge + } + return +} + +// ----- + +// CHECK-LABEL: @empty_region +func @empty_region() -> () { + // CHECK: spv.loop + spv.loop { + } + return +} + +// ----- + +func @wrong_merge_block() -> () { + // expected-error @+1 {{last block must be the merge block with only one 'spv._merge' op}} + spv.loop { + spv.Return + } + return +} + +// ----- + +func @missing_entry_block() -> () { + // expected-error @+1 {{must have an entry block branching to the loop header block}} + spv.loop { + spv._merge + } + return +} + +// ----- + +func @missing_header_block() -> () { + // expected-error @+1 {{must have a loop header block branched from the entry block}} + spv.loop { + ^entry: + spv.Branch ^merge + ^merge: + spv._merge + } + return +} + +// ----- + +func @entry_should_branch_to_header() -> () { + // expected-error @+1 {{entry block must only have one 'spv.Branch' op to the second block}} + spv.loop { + ^entry: + spv.Branch ^merge + ^header: + spv.Branch ^merge + ^merge: + spv._merge + } + return +} + +// ----- + +func @missing_continue_block() -> () { + // expected-error @+1 {{requires a loop continue block branching to the loop header block}} + spv.loop { + ^entry: + spv.Branch ^header + ^header: + spv.Branch ^merge + ^merge: + spv._merge + } + return +} + +// ----- + +func @continue_should_branch_to_header() -> () { + // expected-error @+1 {{second to last block must be the loop continue block that branches to the loop header block}} + spv.loop { + ^entry: + spv.Branch ^header + ^header: + spv.Branch ^continue + ^continue: + spv.Branch ^merge + ^merge: + spv._merge + } + return +} + +// ----- + +func @only_entry_and_continue_branch_to_header() -> () { + // expected-error @+1 {{can only have the entry and loop continue block branching to the loop header block}} + spv.loop { + ^entry: + spv.Branch ^header + ^header: + spv.Branch ^cont1 + ^cont1: + spv.Branch ^header + ^cont2: + spv.Branch ^header + ^merge: + spv._merge + } + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.merge +//===----------------------------------------------------------------------===// + +func @merge() -> () { + // expected-error @+1 {{expects parent op 'spv.loop'}} + spv._merge +} + +// ----- + +func @only_allowed_in_last_block() -> () { + %true = spv.constant true + spv.loop { + spv.Branch ^header + ^header: + spv.BranchConditional %true, ^body, ^merge + ^body: + // expected-error @+1 {{can only be used in the last block of 'spv.loop'}} + spv._merge + ^continue: + spv.Branch ^header + ^merge: + spv._merge + } + return +} + +// ----- + //===----------------------------------------------------------------------===// // spv.Return //===----------------------------------------------------------------------===//