[OpenMP][IRBuilder] Add final clause to task

This patch adds final clause to OpenMP IR Builder.

Reviewed By: Meinersbur

Differential Revision: https://reviews.llvm.org/D126626
This commit is contained in:
Shraiysh Vaishay 2022-06-10 23:07:18 +05:30
parent 13558334f3
commit f62baddac0
3 changed files with 67 additions and 5 deletions

View File

@ -624,9 +624,11 @@ public:
/// \param AllocaIP The insertion point to be used for alloca instructions.
/// \param BodyGenCB Callback that will generate the region code.
/// \param Tied True if the task is tied, false if the task is untied.
/// \param Final i1 value which is `true` if the task is final, `false` if the
/// task is not final.
InsertPointTy createTask(const LocationDescription &Loc,
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
bool Tied = true);
bool Tied = true, Value *Final = nullptr);
/// Functions used to generate reductions. Such functions take two Values
/// representing LHS and RHS of the reduction, respectively, and a reference

View File

@ -1256,7 +1256,7 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTask(const LocationDescription &Loc,
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
bool Tied) {
bool Tied, Value *Final) {
if (!updateToLocation(Loc))
return InsertPointTy();
@ -1285,7 +1285,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
OI.EntryBB = TaskAllocaBB;
OI.OuterAllocaBB = AllocaIP.getBlock();
OI.ExitBB = TaskExitBB;
OI.PostOutlineCB = [this, &Loc, Tied](Function &OutlinedFn) {
OI.PostOutlineCB = [this, &Loc, Tied, Final](Function &OutlinedFn) {
// The input IR here looks like the following-
// ```
// func @current_fn() {
@ -1330,10 +1330,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
Value *ThreadID = getOrCreateThreadID(Ident);
// Argument - `flags`
// If task is tied, then (Flags & 1) == 1.
// If task is untied, then (Flags & 1) == 0.
// Task is tied iff (Flags & 1) == 1.
// Task is untied iff (Flags & 1) == 0.
// Task is final iff (Flags & 2) == 2.
// Task is not final iff (Flags & 2) == 0.
// TODO: Handle the other flags.
Value *Flags = Builder.getInt32(Tied);
if (Final) {
Value *FinalFlag =
Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
Flags = Builder.CreateOr(FinalFlag, Flags);
}
// Argument - `sizeof_kmp_task_t` (TaskSize)
// Tasksize refers to the size in bytes of kmp_task_t data structure

View File

@ -4832,4 +4832,57 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) {
EXPECT_FALSE(verifyModule(*M, &errs()));
}
TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
F->setName("func");
IRBuilder<> Builder(BB);
auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
Builder.SetInsertPoint(BodyBB);
Value *Final = Builder.CreateICmp(
CmpInst::Predicate::ICMP_EQ, F->getArg(0),
ConstantInt::get(Type::getInt32Ty(M->getContext()), 0U));
OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
Builder.restoreIP(OMPBuilder.createTask(Loc, AllocaIP, BodyGenCB,
/*Tied=*/false, Final));
OMPBuilder.finalize();
Builder.CreateRetVoid();
// Check for the `Tied` argument
CallInst *TaskAllocCall = dyn_cast<CallInst>(
OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
->user_back());
ASSERT_NE(TaskAllocCall, nullptr);
BinaryOperator *OrInst =
dyn_cast<BinaryOperator>(TaskAllocCall->getArgOperand(2));
ASSERT_NE(OrInst, nullptr);
EXPECT_EQ(OrInst->getOpcode(), BinaryOperator::BinaryOps::Or);
// One of the arguments to `or` instruction is the tied flag, which is equal
// to zero.
EXPECT_TRUE(any_of(OrInst->operands(), [](Value *op) {
if (ConstantInt *TiedValue = dyn_cast<ConstantInt>(op))
return TiedValue->getSExtValue() == 0;
return false;
}));
// One of the arguments to `or` instruction is the final condition.
EXPECT_TRUE(any_of(OrInst->operands(), [Final](Value *op) {
if (SelectInst *Select = dyn_cast<SelectInst>(op)) {
ConstantInt *TrueValue = dyn_cast<ConstantInt>(Select->getTrueValue());
ConstantInt *FalseValue = dyn_cast<ConstantInt>(Select->getFalseValue());
if (!TrueValue || !FalseValue)
return false;
return Select->getCondition() == Final &&
TrueValue->getSExtValue() == 2 && FalseValue->getSExtValue() == 0;
}
return false;
}));
EXPECT_FALSE(verifyModule(*M, &errs()));
}
} // namespace