[MLIR][SPIRVToLLVM] Branch weights support for BranchConditional conversion

Conversion of `spv.BranchConditional` now supports branch weights
that are mapped to weights vector in `llvm.cond_br`.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D84657
This commit is contained in:
George Mitenkov 2020-07-29 09:16:47 +03:00
parent 89247792c5
commit 1f4aa30a4f
2 changed files with 13 additions and 10 deletions

View File

@ -459,13 +459,18 @@ public:
LogicalResult
matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// There is no support of branch weights in LLVM dialect at the moment.
if (auto weights = op.branch_weights())
return failure();
// If branch weights exist, map them to 32-bit integer vector.
ElementsAttr branchWeights = nullptr;
if (auto weights = op.branch_weights()) {
VectorType weightType = VectorType::get(2, rewriter.getI32Type());
branchWeights =
DenseElementsAttr::get(weightType, weights.getValue().getValue());
}
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, op.condition(), op.getTrueBlock(), op.getTrueBlockArguments(),
op.getFalseBlock(), op.getFalseBlockArguments());
op, op.condition(), op.getTrueBlockArguments(),
op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
op.getFalseBlock());
return success();
}
};

View File

@ -66,16 +66,14 @@ spv.module Logical GLSL450 {
^inner_false(%arg3: i32, %arg4: i32):
spv.Return
}
}
// -----
spv.module Logical GLSL450 {
spv.func @cond_branch_with_weights(%cond: i1) -> () "None" {
// expected-error@+1 {{failed to legalize operation 'spv.BranchConditional' that was explicitly marked illegal}}
// CHECK: llvm.cond_br %{{.*}} weights(dense<[1, 2]> : vector<2xi32>), ^bb1, ^bb2
spv.BranchConditional %cond [1, 2], ^true, ^false
// CHECK: ^bb1:
^true:
spv.Return
// CHECK: ^bb2:
^false:
spv.Return
}