forked from OSchip/llvm-project
[MLIR][LLVMDialect] SelectionOp conversion pattern
This patch introduces conversion pattern for `spv.selection` op. The conversion can only be applied to selection with all blocks being reachable. Moreover, selection with control attributes "Flatten" and "DontFlatten" is not supported. Since the `PatternRewriter` hook for block merging has not been implemented for `ConversionPatternRewriter`, merge and continue blocks are kept separately. Reviewed By: antiagainst, ftynse Differential Revision: https://reviews.llvm.org/D83860
This commit is contained in:
parent
7b5bddfd03
commit
61dd481f11
|
@ -613,6 +613,84 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class MergePattern : public SPIRVToLLVMConversion<spirv::MergeOp> {
|
||||
public:
|
||||
using SPIRVToLLVMConversion<spirv::MergeOp>::SPIRVToLLVMConversion;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::MergeOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Converts `spv.selection` with `spv.BranchConditional` in its header block.
|
||||
/// All blocks within selection should be reachable for conversion to succeed.
|
||||
class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
|
||||
public:
|
||||
using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::SelectionOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// There is no support for `Flatten` or `DontFlatten` selection control at
|
||||
// the moment. This are just compiler hints and can be performed during the
|
||||
// optimization passes.
|
||||
if (op.selection_control() != spirv::SelectionControl::None)
|
||||
return failure();
|
||||
|
||||
// `spv.selection` should have at least two blocks: one selection header
|
||||
// block and one merge block. If no blocks are present, or control flow
|
||||
// branches straight to merge block (two blocks are present), the op is
|
||||
// redundant and it is erased.
|
||||
if (op.body().getBlocks().size() <= 2) {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
Location loc = op.getLoc();
|
||||
|
||||
// Split the current block after `spv.selection`. The remaing ops will be
|
||||
// used in `continueBlock`.
|
||||
auto *currentBlock = rewriter.getInsertionBlock();
|
||||
rewriter.setInsertionPointAfter(op);
|
||||
auto position = rewriter.getInsertionPoint();
|
||||
auto *continueBlock = rewriter.splitBlock(currentBlock, position);
|
||||
|
||||
// Extract conditional branch information from the header block. By SPIR-V
|
||||
// dialect spec, it should contain `spv.BranchConditional` or `spv.Switch`
|
||||
// op. Note that `spv.Switch op` is not supported at the moment in the
|
||||
// SPIR-V dialect. Remove this block when finished.
|
||||
auto *headerBlock = op.getHeaderBlock();
|
||||
assert(headerBlock->getOperations().size() == 1);
|
||||
auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
|
||||
headerBlock->getOperations().front());
|
||||
if (!condBrOp)
|
||||
return failure();
|
||||
rewriter.eraseBlock(headerBlock);
|
||||
|
||||
// Branch from merge block to continue block.
|
||||
auto *mergeBlock = op.getMergeBlock();
|
||||
Operation *terminator = mergeBlock->getTerminator();
|
||||
ValueRange terminatorOperands = terminator->getOperands();
|
||||
rewriter.setInsertionPointToEnd(mergeBlock);
|
||||
rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
|
||||
|
||||
// Link current block to `true` and `false` blocks within the selection.
|
||||
Block *trueBlock = condBrOp.getTrueBlock();
|
||||
Block *falseBlock = condBrOp.getFalseBlock();
|
||||
rewriter.setInsertionPointToEnd(currentBlock);
|
||||
rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
|
||||
condBrOp.trueTargetOperands(), falseBlock,
|
||||
condBrOp.falseTargetOperands());
|
||||
|
||||
rewriter.inlineRegionBefore(op.body(), continueBlock);
|
||||
rewriter.replaceOp(op, continueBlock->getArguments());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
|
||||
/// puts a restriction on `Shift` and `Base` to have the same bit width,
|
||||
/// `Shift` is zero or sign extended to match this specification. Cases when
|
||||
|
@ -843,6 +921,7 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
|
|||
|
||||
// Control Flow ops
|
||||
BranchConversionPattern, BranchConditionalConversionPattern,
|
||||
SelectionPattern, MergePattern,
|
||||
|
||||
// Function Call op
|
||||
FunctionCallPattern,
|
||||
|
|
|
@ -80,3 +80,93 @@ spv.module Logical GLSL450 {
|
|||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.selection
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @selection_empty() -> () "None" {
|
||||
// CHECK: llvm.return
|
||||
spv.selection {
|
||||
}
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.func @selection_with_merge_block_only() -> () "None" {
|
||||
%cond = spv.constant true
|
||||
// CHECK: llvm.return
|
||||
spv.selection {
|
||||
spv.BranchConditional %cond, ^merge, ^merge
|
||||
^merge:
|
||||
spv._merge
|
||||
}
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.func @selection_with_true_block_only() -> () "None" {
|
||||
// CHECK: %[[COND:.*]] = llvm.mlir.constant(true) : !llvm.i1
|
||||
%cond = spv.constant true
|
||||
// CHECK: llvm.cond_br %[[COND]], ^bb1, ^bb2
|
||||
spv.selection {
|
||||
spv.BranchConditional %cond, ^true, ^merge
|
||||
// CHECK: ^bb1:
|
||||
^true:
|
||||
// CHECK: llvm.br ^bb2
|
||||
spv.Branch ^merge
|
||||
// CHECK: ^bb2:
|
||||
^merge:
|
||||
// CHECK: llvm.br ^bb3
|
||||
spv._merge
|
||||
}
|
||||
// CHECK: ^bb3:
|
||||
// CHECK-NEXT: llvm.return
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.func @selection_with_both_true_and_false_block() -> () "None" {
|
||||
// CHECK: %[[COND:.*]] = llvm.mlir.constant(true) : !llvm.i1
|
||||
%cond = spv.constant true
|
||||
// CHECK: llvm.cond_br %[[COND]], ^bb1, ^bb2
|
||||
spv.selection {
|
||||
spv.BranchConditional %cond, ^true, ^false
|
||||
// CHECK: ^bb1:
|
||||
^true:
|
||||
// CHECK: llvm.br ^bb3
|
||||
spv.Branch ^merge
|
||||
// CHECK: ^bb2:
|
||||
^false:
|
||||
// CHECK: llvm.br ^bb3
|
||||
spv.Branch ^merge
|
||||
// CHECK: ^bb3:
|
||||
^merge:
|
||||
// CHECK: llvm.br ^bb4
|
||||
spv._merge
|
||||
}
|
||||
// CHECK: ^bb4:
|
||||
// CHECK-NEXT: llvm.return
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.func @selection_with_early_return(%arg0: i1) -> i32 "None" {
|
||||
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||
%0 = spv.constant 0 : i32
|
||||
// CHECK: llvm.cond_br %{{.*}}, ^bb1(%[[ZERO]] : !llvm.i32), ^bb2
|
||||
spv.selection {
|
||||
spv.BranchConditional %arg0, ^true(%0 : i32), ^merge
|
||||
// CHECK: ^bb1(%[[ARG:.*]]: !llvm.i32):
|
||||
^true(%arg1: i32):
|
||||
// CHECK: llvm.return %[[ARG]] : !llvm.i32
|
||||
spv.ReturnValue %arg1 : i32
|
||||
// CHECK: ^bb2:
|
||||
^merge:
|
||||
// CHECK: llvm.br ^bb3
|
||||
spv._merge
|
||||
}
|
||||
// CHECK: ^bb3:
|
||||
%one = spv.constant 1 : i32
|
||||
spv.ReturnValue %one : i32
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue