forked from OSchip/llvm-project
[MLIR][LLVMDialect] Added branch weights attribute to CondBrOp
This patch introduces branch weights metadata to `llvm.cond_br` op in LLVM Dialect. It is modelled as optional `ElementsAttr`, for example: ``` llvm.cond_br %cond weights(dense<[1, 3]> : vector<2xi32>), ^bb1, ^bb2 ``` When exporting to proper LLVM, this attribute is transformed into metadata node. The test for metadata creation is added to `../Target/llvmir.mlir`. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D83658
This commit is contained in:
parent
205e8b7e89
commit
99d03f0391
|
@ -514,21 +514,29 @@ def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
|
|||
NoSideEffect]> {
|
||||
let arguments = (ins LLVMI1:$condition,
|
||||
Variadic<LLVM_Type>:$trueDestOperands,
|
||||
Variadic<LLVM_Type>:$falseDestOperands);
|
||||
Variadic<LLVM_Type>:$falseDestOperands,
|
||||
OptionalAttr<ElementsAttr>:$branch_weights);
|
||||
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
|
||||
let assemblyFormat = [{
|
||||
$condition `,`
|
||||
$condition ( `weights` `(` $branch_weights^ `)` )? `,`
|
||||
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
|
||||
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
|
||||
attr-dict
|
||||
}];
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, Value condition,"
|
||||
"Block *trueDest, ValueRange trueOperands,"
|
||||
"Block *falseDest, ValueRange falseOperands", [{
|
||||
build(builder, result, condition, trueOperands, falseOperands, trueDest,
|
||||
falseDest);
|
||||
"OpBuilder &builder, OperationState &result, Value condition,"
|
||||
"Block *trueDest, ValueRange trueOperands,"
|
||||
"Block *falseDest, ValueRange falseOperands,"
|
||||
"Optional<std::pair<uint32_t, uint32_t>> weights = {}", [{
|
||||
ElementsAttr weightsAttr;
|
||||
if (weights) {
|
||||
weightsAttr =
|
||||
builder.getI32VectorAttr({static_cast<int32_t>(weights->first),
|
||||
static_cast<int32_t>(weights->second)});
|
||||
}
|
||||
build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
|
||||
trueDest, falseDest);
|
||||
}]>, OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, Value condition,"
|
||||
"Block *trueDest, Block *falseDest, ValueRange falseOperands = {}", [{
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "llvm/IR/DerivedTypes.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/MDBuilder.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
|
@ -594,9 +595,22 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
|
|||
return success();
|
||||
}
|
||||
if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
|
||||
auto weights = condbrOp.branch_weights();
|
||||
llvm::MDNode *branchWeights = nullptr;
|
||||
if (weights) {
|
||||
// Map weight attributes to LLVM metadata.
|
||||
auto trueWeight =
|
||||
weights.getValue().getValue(0).cast<IntegerAttr>().getInt();
|
||||
auto falseWeight =
|
||||
weights.getValue().getValue(1).cast<IntegerAttr>().getInt();
|
||||
branchWeights =
|
||||
llvm::MDBuilder(llvmModule->getContext())
|
||||
.createBranchWeights(static_cast<uint32_t>(trueWeight),
|
||||
static_cast<uint32_t>(falseWeight));
|
||||
}
|
||||
builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)),
|
||||
blockMapping[condbrOp.getSuccessor(0)],
|
||||
blockMapping[condbrOp.getSuccessor(1)]);
|
||||
blockMapping[condbrOp.getSuccessor(1)], branchWeights);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -1252,3 +1252,17 @@ llvm.mlir.global internal constant @taker_of_address() : !llvm<"void()*"> {
|
|||
%0 = llvm.mlir.addressof @address_taken : !llvm<"void()*">
|
||||
llvm.return %0 : !llvm<"void()*">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that branch weight attributes are exported properly as metadata.
|
||||
llvm.func @cond_br_weights(%cond : !llvm.i1, %arg0 : !llvm.i32, %arg1 : !llvm.i32) -> !llvm.i32 {
|
||||
// CHECK: !prof ![[NODE:[0-9]+]]
|
||||
llvm.cond_br %cond weights(dense<[5, 10]> : vector<2xi32>), ^bb1, ^bb2
|
||||
^bb1: // pred: ^bb0
|
||||
llvm.return %arg0 : !llvm.i32
|
||||
^bb2: // pred: ^bb0
|
||||
llvm.return %arg1 : !llvm.i32
|
||||
}
|
||||
|
||||
// CHECK: ![[NODE]] = !{!"branch_weights", i32 5, i32 10}
|
||||
|
|
Loading…
Reference in New Issue