[MLIR][SPIRV] Properly (de-)serialize BranchConditionalOp.

Implements proper (de-)serialization logic for BranchConditionalOp when
such ops have true/false target operands.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D101602
This commit is contained in:
KareemErgawy-TomTom 2021-05-07 08:59:35 +02:00
parent a95473c563
commit e4dee7e730
4 changed files with 100 additions and 7 deletions

View File

@ -1573,7 +1573,8 @@ LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
uint32_t value = operands[i];
Block *predecessor = getOrCreateBlock(operands[i + 1]);
blockPhiInfo[predecessor].push_back(value);
std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
blockPhiInfo[predecessorTargetPair].push_back(value);
LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor
<< " with arg id = " << value << '\n');
}
@ -1853,7 +1854,8 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
OpBuilder::InsertionGuard guard(opBuilder);
for (const auto &info : blockPhiInfo) {
Block *block = info.first;
Block *block = info.first.first;
Block *target = info.first.second;
const BlockPhiInfo &phiInfo = info.second;
LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n");
LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n");
@ -1882,6 +1884,24 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
blockArgs);
branchOp.erase();
} else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
assert((branchCondOp.getTrueBlock() == target ||
branchCondOp.getFalseBlock() == target) &&
"expected target to be either the true or false target");
if (target == branchCondOp.trueTarget())
opBuilder.create<spirv::BranchConditionalOp>(
branchCondOp.getLoc(), branchCondOp.condition(), blockArgs,
branchCondOp.getFalseBlockArguments(),
branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(),
branchCondOp.falseTarget());
else
opBuilder.create<spirv::BranchConditionalOp>(
branchCondOp.getLoc(), branchCondOp.condition(),
branchCondOp.getTrueBlockArguments(), blockArgs,
branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(),
branchCondOp.getFalseBlock());
branchCondOp.erase();
} else {
return emitError(unknownLoc, "unimplemented terminator for Phi creation");
}

View File

@ -560,8 +560,10 @@ private:
// Header block to its merge (and continue) target mapping.
BlockMergeInfoMap blockMergeInfo;
// Block to its phi (block argument) mapping.
DenseMap<Block *, BlockPhiInfo> blockPhiInfo;
// For each pair of {predecessor, target} blocks, maps the pair of blocks to
// the list of phi arguments passed from predecessor to target.
DenseMap<std::pair<Block * /*predecessor*/, Block * /*target*/>, BlockPhiInfo>
blockPhiInfo;
// Result <id> to value mapping.
DenseMap<uint32_t, Value> valueMap;

View File

@ -959,7 +959,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
// OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
// So we need to collect all predecessor blocks and the arguments they send
// to this block.
SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors;
SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
for (Block *predecessor : block->getPredecessors()) {
auto *terminator = predecessor->getTerminator();
// The predecessor here is the immediate one according to MLIR's IR
@ -971,7 +971,21 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
// structured control flow op's merge block.
predecessor = getPhiIncomingBlock(predecessor);
if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
predecessors.emplace_back(predecessor, branchOp.operand_begin());
predecessors.emplace_back(predecessor, branchOp.getOperands());
} else if (auto branchCondOp =
dyn_cast<spirv::BranchConditionalOp>(terminator)) {
Optional<OperandRange> blockOperands;
for (auto successorIdx :
llvm::seq<unsigned>(0, predecessor->getNumSuccessors()))
if (predecessor->getSuccessors()[successorIdx] == block) {
blockOperands = branchCondOp.getSuccessorOperands(successorIdx);
break;
}
assert(blockOperands && !blockOperands->empty() &&
"expected non-empty block operand range");
predecessors.emplace_back(predecessor, *blockOperands);
} else {
return terminator->emitError("unimplemented terminator for Phi creation");
}
@ -996,7 +1010,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
phiArgs.push_back(phiID);
for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
Value value = *(predecessors[predIndex].second + argIndex);
Value value = predecessors[predIndex].second[argIndex];
uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
<< ") value " << value << ' ');

View File

@ -286,3 +286,60 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.EntryPoint "GLCompute" @fmul_kernel
spv.ExecutionMode @fmul_kernel "LocalSize", 32, 1, 1
}
// -----
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK-LABEL: @cond_branch_true_argument
spv.func @cond_branch_true_argument() -> () "None" {
%true = spv.Constant true
%zero = spv.Constant 0 : i32
%one = spv.Constant 1 : i32
// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}}, %{{.*}} : i32, i32), ^[[false1:.*]]
spv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1
// CHECK: [[true1]](%{{.*}}: i32, %{{.*}}: i32)
^true1(%arg0: i32, %arg1: i32):
spv.Return
// CHECK: [[false1]]:
^false1:
spv.Return
}
}
// -----
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK-LABEL: @cond_branch_false_argument
spv.func @cond_branch_false_argument() -> () "None" {
%true = spv.Constant true
%zero = spv.Constant 0 : i32
%one = spv.Constant 1 : i32
// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]], ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32)
spv.BranchConditional %true, ^true1, ^false1(%zero, %zero: i32, i32)
// CHECK: [[true1]]:
^true1:
spv.Return
// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32):
^false1(%arg0: i32, %arg1: i32):
spv.Return
}
}
// -----
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK-LABEL: @cond_branch_true_and_false_argument
spv.func @cond_branch_true_and_false_argument() -> () "None" {
%true = spv.Constant true
%zero = spv.Constant 0 : i32
%one = spv.Constant 1 : i32
// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}} : i32), ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32)
spv.BranchConditional %true, ^true1(%one: i32), ^false1(%zero, %zero: i32, i32)
// CHECK: [[true1]](%{{.*}}: i32):
^true1(%arg0: i32):
spv.Return
// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32):
^false1(%arg1: i32, %arg2: i32):
spv.Return
}
}