forked from OSchip/llvm-project
[MLIR][SPIRV] Properly (de-)serialize BranchConditionalOp.
Implements proper (de-)serialization logic for BranchConditionalOp when such ops have true/false target operands. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D101602
This commit is contained in:
parent
a95473c563
commit
e4dee7e730
|
@ -1573,7 +1573,8 @@ LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
|
|||
for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
|
||||
uint32_t value = operands[i];
|
||||
Block *predecessor = getOrCreateBlock(operands[i + 1]);
|
||||
blockPhiInfo[predecessor].push_back(value);
|
||||
std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
|
||||
blockPhiInfo[predecessorTargetPair].push_back(value);
|
||||
LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor
|
||||
<< " with arg id = " << value << '\n');
|
||||
}
|
||||
|
@ -1853,7 +1854,8 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
|
|||
OpBuilder::InsertionGuard guard(opBuilder);
|
||||
|
||||
for (const auto &info : blockPhiInfo) {
|
||||
Block *block = info.first;
|
||||
Block *block = info.first.first;
|
||||
Block *target = info.first.second;
|
||||
const BlockPhiInfo &phiInfo = info.second;
|
||||
LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n");
|
||||
LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n");
|
||||
|
@ -1882,6 +1884,24 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
|
|||
opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
|
||||
blockArgs);
|
||||
branchOp.erase();
|
||||
} else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
|
||||
assert((branchCondOp.getTrueBlock() == target ||
|
||||
branchCondOp.getFalseBlock() == target) &&
|
||||
"expected target to be either the true or false target");
|
||||
if (target == branchCondOp.trueTarget())
|
||||
opBuilder.create<spirv::BranchConditionalOp>(
|
||||
branchCondOp.getLoc(), branchCondOp.condition(), blockArgs,
|
||||
branchCondOp.getFalseBlockArguments(),
|
||||
branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(),
|
||||
branchCondOp.falseTarget());
|
||||
else
|
||||
opBuilder.create<spirv::BranchConditionalOp>(
|
||||
branchCondOp.getLoc(), branchCondOp.condition(),
|
||||
branchCondOp.getTrueBlockArguments(), blockArgs,
|
||||
branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(),
|
||||
branchCondOp.getFalseBlock());
|
||||
|
||||
branchCondOp.erase();
|
||||
} else {
|
||||
return emitError(unknownLoc, "unimplemented terminator for Phi creation");
|
||||
}
|
||||
|
|
|
@ -560,8 +560,10 @@ private:
|
|||
// Header block to its merge (and continue) target mapping.
|
||||
BlockMergeInfoMap blockMergeInfo;
|
||||
|
||||
// Block to its phi (block argument) mapping.
|
||||
DenseMap<Block *, BlockPhiInfo> blockPhiInfo;
|
||||
// For each pair of {predecessor, target} blocks, maps the pair of blocks to
|
||||
// the list of phi arguments passed from predecessor to target.
|
||||
DenseMap<std::pair<Block * /*predecessor*/, Block * /*target*/>, BlockPhiInfo>
|
||||
blockPhiInfo;
|
||||
|
||||
// Result <id> to value mapping.
|
||||
DenseMap<uint32_t, Value> valueMap;
|
||||
|
|
|
@ -959,7 +959,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
|
|||
// OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
|
||||
// So we need to collect all predecessor blocks and the arguments they send
|
||||
// to this block.
|
||||
SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors;
|
||||
SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
|
||||
for (Block *predecessor : block->getPredecessors()) {
|
||||
auto *terminator = predecessor->getTerminator();
|
||||
// The predecessor here is the immediate one according to MLIR's IR
|
||||
|
@ -971,7 +971,21 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
|
|||
// structured control flow op's merge block.
|
||||
predecessor = getPhiIncomingBlock(predecessor);
|
||||
if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
|
||||
predecessors.emplace_back(predecessor, branchOp.operand_begin());
|
||||
predecessors.emplace_back(predecessor, branchOp.getOperands());
|
||||
} else if (auto branchCondOp =
|
||||
dyn_cast<spirv::BranchConditionalOp>(terminator)) {
|
||||
Optional<OperandRange> blockOperands;
|
||||
|
||||
for (auto successorIdx :
|
||||
llvm::seq<unsigned>(0, predecessor->getNumSuccessors()))
|
||||
if (predecessor->getSuccessors()[successorIdx] == block) {
|
||||
blockOperands = branchCondOp.getSuccessorOperands(successorIdx);
|
||||
break;
|
||||
}
|
||||
|
||||
assert(blockOperands && !blockOperands->empty() &&
|
||||
"expected non-empty block operand range");
|
||||
predecessors.emplace_back(predecessor, *blockOperands);
|
||||
} else {
|
||||
return terminator->emitError("unimplemented terminator for Phi creation");
|
||||
}
|
||||
|
@ -996,7 +1010,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
|
|||
phiArgs.push_back(phiID);
|
||||
|
||||
for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
|
||||
Value value = *(predecessors[predIndex].second + argIndex);
|
||||
Value value = predecessors[predIndex].second[argIndex];
|
||||
uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
|
||||
LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
|
||||
<< ") value " << value << ' ');
|
||||
|
|
|
@ -286,3 +286,60 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
|||
spv.EntryPoint "GLCompute" @fmul_kernel
|
||||
spv.ExecutionMode @fmul_kernel "LocalSize", 32, 1, 1
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||
// CHECK-LABEL: @cond_branch_true_argument
|
||||
spv.func @cond_branch_true_argument() -> () "None" {
|
||||
%true = spv.Constant true
|
||||
%zero = spv.Constant 0 : i32
|
||||
%one = spv.Constant 1 : i32
|
||||
// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}}, %{{.*}} : i32, i32), ^[[false1:.*]]
|
||||
spv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1
|
||||
// CHECK: [[true1]](%{{.*}}: i32, %{{.*}}: i32)
|
||||
^true1(%arg0: i32, %arg1: i32):
|
||||
spv.Return
|
||||
// CHECK: [[false1]]:
|
||||
^false1:
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||
// CHECK-LABEL: @cond_branch_false_argument
|
||||
spv.func @cond_branch_false_argument() -> () "None" {
|
||||
%true = spv.Constant true
|
||||
%zero = spv.Constant 0 : i32
|
||||
%one = spv.Constant 1 : i32
|
||||
// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]], ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32)
|
||||
spv.BranchConditional %true, ^true1, ^false1(%zero, %zero: i32, i32)
|
||||
// CHECK: [[true1]]:
|
||||
^true1:
|
||||
spv.Return
|
||||
// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32):
|
||||
^false1(%arg0: i32, %arg1: i32):
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||
// CHECK-LABEL: @cond_branch_true_and_false_argument
|
||||
spv.func @cond_branch_true_and_false_argument() -> () "None" {
|
||||
%true = spv.Constant true
|
||||
%zero = spv.Constant 0 : i32
|
||||
%one = spv.Constant 1 : i32
|
||||
// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}} : i32), ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32)
|
||||
spv.BranchConditional %true, ^true1(%one: i32), ^false1(%zero, %zero: i32, i32)
|
||||
// CHECK: [[true1]](%{{.*}}: i32):
|
||||
^true1(%arg0: i32):
|
||||
spv.Return
|
||||
// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32):
|
||||
^false1(%arg1: i32, %arg2: i32):
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue