[MLIR][LLVMDialect] Added branch weights attribute to CondBrOp

This patch introduces branch weights metadata to `llvm.cond_br` op in
LLVM Dialect. It is modelled as optional `ElementsAttr`, for example:
```
llvm.cond_br %cond weights(dense<[1, 3]> : vector<2xi32>), ^bb1, ^bb2
```
When exporting to proper LLVM, this attribute is transformed into metadata
node. The test for metadata creation is added to `../Target/llvmir.mlir`.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D83658
This commit is contained in:
George Mitenkov 2020-07-24 09:41:22 +03:00
parent 205e8b7e89
commit 99d03f0391
3 changed files with 44 additions and 8 deletions

View File

@ -514,21 +514,29 @@ def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
NoSideEffect]> {
let arguments = (ins LLVMI1:$condition,
Variadic<LLVM_Type>:$trueDestOperands,
Variadic<LLVM_Type>:$falseDestOperands);
Variadic<LLVM_Type>:$falseDestOperands,
OptionalAttr<ElementsAttr>:$branch_weights);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
let assemblyFormat = [{
$condition `,`
$condition ( `weights` `(` $branch_weights^ `)` )? `,`
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
attr-dict
}];
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value condition,"
"Block *trueDest, ValueRange trueOperands,"
"Block *falseDest, ValueRange falseOperands", [{
build(builder, result, condition, trueOperands, falseOperands, trueDest,
falseDest);
"OpBuilder &builder, OperationState &result, Value condition,"
"Block *trueDest, ValueRange trueOperands,"
"Block *falseDest, ValueRange falseOperands,"
"Optional<std::pair<uint32_t, uint32_t>> weights = {}", [{
ElementsAttr weightsAttr;
if (weights) {
weightsAttr =
builder.getI32VectorAttr({static_cast<int32_t>(weights->first),
static_cast<int32_t>(weights->second)});
}
build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
trueDest, falseDest);
}]>, OpBuilder<
"OpBuilder &builder, OperationState &result, Value condition,"
"Block *trueDest, Block *falseDest, ValueRange falseOperands = {}", [{

View File

@ -30,6 +30,7 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
@ -594,9 +595,22 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
return success();
}
if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
auto weights = condbrOp.branch_weights();
llvm::MDNode *branchWeights = nullptr;
if (weights) {
// Map weight attributes to LLVM metadata.
auto trueWeight =
weights.getValue().getValue(0).cast<IntegerAttr>().getInt();
auto falseWeight =
weights.getValue().getValue(1).cast<IntegerAttr>().getInt();
branchWeights =
llvm::MDBuilder(llvmModule->getContext())
.createBranchWeights(static_cast<uint32_t>(trueWeight),
static_cast<uint32_t>(falseWeight));
}
builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)),
blockMapping[condbrOp.getSuccessor(0)],
blockMapping[condbrOp.getSuccessor(1)]);
blockMapping[condbrOp.getSuccessor(1)], branchWeights);
return success();
}

View File

@ -1252,3 +1252,17 @@ llvm.mlir.global internal constant @taker_of_address() : !llvm<"void()*"> {
%0 = llvm.mlir.addressof @address_taken : !llvm<"void()*">
llvm.return %0 : !llvm<"void()*">
}
// -----
// Check that branch weight attributes are exported properly as metadata.
llvm.func @cond_br_weights(%cond : !llvm.i1, %arg0 : !llvm.i32, %arg1 : !llvm.i32) -> !llvm.i32 {
// CHECK: !prof ![[NODE:[0-9]+]]
llvm.cond_br %cond weights(dense<[5, 10]> : vector<2xi32>), ^bb1, ^bb2
^bb1: // pred: ^bb0
llvm.return %arg0 : !llvm.i32
^bb2: // pred: ^bb0
llvm.return %arg1 : !llvm.i32
}
// CHECK: ![[NODE]] = !{!"branch_weights", i32 5, i32 10}