[MLIR][LLVMDialect] SelectionOp conversion pattern

This patch introduces conversion pattern for `spv.selection` op.
The conversion can only be applied to selection with all blocks being
reachable. Moreover, selection with control attributes "Flatten" and
"DontFlatten" is not supported.
Since the `PatternRewriter` hook for block merging has not been implemented
for `ConversionPatternRewriter`, merge and continue blocks are kept
separately.

Reviewed By: antiagainst, ftynse

Differential Revision: https://reviews.llvm.org/D83860
This commit is contained in:
George Mitenkov 2020-07-21 16:45:36 +03:00
parent 7b5bddfd03
commit 61dd481f11
2 changed files with 169 additions and 0 deletions

View File

@ -613,6 +613,84 @@ public:
}
};
class MergePattern : public SPIRVToLLVMConversion<spirv::MergeOp> {
public:
using SPIRVToLLVMConversion<spirv::MergeOp>::SPIRVToLLVMConversion;
LogicalResult
matchAndRewrite(spirv::MergeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return success();
}
};
/// Converts `spv.selection` with `spv.BranchConditional` in its header block.
/// All blocks within selection should be reachable for conversion to succeed.
class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
public:
using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
LogicalResult
matchAndRewrite(spirv::SelectionOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// There is no support for `Flatten` or `DontFlatten` selection control at
// the moment. This are just compiler hints and can be performed during the
// optimization passes.
if (op.selection_control() != spirv::SelectionControl::None)
return failure();
// `spv.selection` should have at least two blocks: one selection header
// block and one merge block. If no blocks are present, or control flow
// branches straight to merge block (two blocks are present), the op is
// redundant and it is erased.
if (op.body().getBlocks().size() <= 2) {
rewriter.eraseOp(op);
return success();
}
Location loc = op.getLoc();
// Split the current block after `spv.selection`. The remaing ops will be
// used in `continueBlock`.
auto *currentBlock = rewriter.getInsertionBlock();
rewriter.setInsertionPointAfter(op);
auto position = rewriter.getInsertionPoint();
auto *continueBlock = rewriter.splitBlock(currentBlock, position);
// Extract conditional branch information from the header block. By SPIR-V
// dialect spec, it should contain `spv.BranchConditional` or `spv.Switch`
// op. Note that `spv.Switch op` is not supported at the moment in the
// SPIR-V dialect. Remove this block when finished.
auto *headerBlock = op.getHeaderBlock();
assert(headerBlock->getOperations().size() == 1);
auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
headerBlock->getOperations().front());
if (!condBrOp)
return failure();
rewriter.eraseBlock(headerBlock);
// Branch from merge block to continue block.
auto *mergeBlock = op.getMergeBlock();
Operation *terminator = mergeBlock->getTerminator();
ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(mergeBlock);
rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
// Link current block to `true` and `false` blocks within the selection.
Block *trueBlock = condBrOp.getTrueBlock();
Block *falseBlock = condBrOp.getFalseBlock();
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
condBrOp.trueTargetOperands(), falseBlock,
condBrOp.falseTargetOperands());
rewriter.inlineRegionBefore(op.body(), continueBlock);
rewriter.replaceOp(op, continueBlock->getArguments());
return success();
}
};
/// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
/// puts a restriction on `Shift` and `Base` to have the same bit width,
/// `Shift` is zero or sign extended to match this specification. Cases when
@ -843,6 +921,7 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
// Control Flow ops
BranchConversionPattern, BranchConditionalConversionPattern,
SelectionPattern, MergePattern,
// Function Call op
FunctionCallPattern,

View File

@ -80,3 +80,93 @@ spv.module Logical GLSL450 {
spv.Return
}
}
// -----
//===----------------------------------------------------------------------===//
// spv.selection
//===----------------------------------------------------------------------===//
spv.module Logical GLSL450 {
spv.func @selection_empty() -> () "None" {
// CHECK: llvm.return
spv.selection {
}
spv.Return
}
spv.func @selection_with_merge_block_only() -> () "None" {
%cond = spv.constant true
// CHECK: llvm.return
spv.selection {
spv.BranchConditional %cond, ^merge, ^merge
^merge:
spv._merge
}
spv.Return
}
spv.func @selection_with_true_block_only() -> () "None" {
// CHECK: %[[COND:.*]] = llvm.mlir.constant(true) : !llvm.i1
%cond = spv.constant true
// CHECK: llvm.cond_br %[[COND]], ^bb1, ^bb2
spv.selection {
spv.BranchConditional %cond, ^true, ^merge
// CHECK: ^bb1:
^true:
// CHECK: llvm.br ^bb2
spv.Branch ^merge
// CHECK: ^bb2:
^merge:
// CHECK: llvm.br ^bb3
spv._merge
}
// CHECK: ^bb3:
// CHECK-NEXT: llvm.return
spv.Return
}
spv.func @selection_with_both_true_and_false_block() -> () "None" {
// CHECK: %[[COND:.*]] = llvm.mlir.constant(true) : !llvm.i1
%cond = spv.constant true
// CHECK: llvm.cond_br %[[COND]], ^bb1, ^bb2
spv.selection {
spv.BranchConditional %cond, ^true, ^false
// CHECK: ^bb1:
^true:
// CHECK: llvm.br ^bb3
spv.Branch ^merge
// CHECK: ^bb2:
^false:
// CHECK: llvm.br ^bb3
spv.Branch ^merge
// CHECK: ^bb3:
^merge:
// CHECK: llvm.br ^bb4
spv._merge
}
// CHECK: ^bb4:
// CHECK-NEXT: llvm.return
spv.Return
}
spv.func @selection_with_early_return(%arg0: i1) -> i32 "None" {
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
%0 = spv.constant 0 : i32
// CHECK: llvm.cond_br %{{.*}}, ^bb1(%[[ZERO]] : !llvm.i32), ^bb2
spv.selection {
spv.BranchConditional %arg0, ^true(%0 : i32), ^merge
// CHECK: ^bb1(%[[ARG:.*]]: !llvm.i32):
^true(%arg1: i32):
// CHECK: llvm.return %[[ARG]] : !llvm.i32
spv.ReturnValue %arg1 : i32
// CHECK: ^bb2:
^merge:
// CHECK: llvm.br ^bb3
spv._merge
}
// CHECK: ^bb3:
%one = spv.constant 1 : i32
spv.ReturnValue %one : i32
}
}