forked from OSchip/llvm-project
[OpenMP][IRBuilder] Add final clause to task
This patch adds final clause to OpenMP IR Builder. Reviewed By: Meinersbur Differential Revision: https://reviews.llvm.org/D126626
This commit is contained in:
parent
13558334f3
commit
f62baddac0
|
@ -624,9 +624,11 @@ public:
|
||||||
/// \param AllocaIP The insertion point to be used for alloca instructions.
|
/// \param AllocaIP The insertion point to be used for alloca instructions.
|
||||||
/// \param BodyGenCB Callback that will generate the region code.
|
/// \param BodyGenCB Callback that will generate the region code.
|
||||||
/// \param Tied True if the task is tied, false if the task is untied.
|
/// \param Tied True if the task is tied, false if the task is untied.
|
||||||
|
/// \param Final i1 value which is `true` if the task is final, `false` if the
|
||||||
|
/// task is not final.
|
||||||
InsertPointTy createTask(const LocationDescription &Loc,
|
InsertPointTy createTask(const LocationDescription &Loc,
|
||||||
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
|
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
|
||||||
bool Tied = true);
|
bool Tied = true, Value *Final = nullptr);
|
||||||
|
|
||||||
/// Functions used to generate reductions. Such functions take two Values
|
/// Functions used to generate reductions. Such functions take two Values
|
||||||
/// representing LHS and RHS of the reduction, respectively, and a reference
|
/// representing LHS and RHS of the reduction, respectively, and a reference
|
||||||
|
|
|
@ -1256,7 +1256,7 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
|
||||||
OpenMPIRBuilder::InsertPointTy
|
OpenMPIRBuilder::InsertPointTy
|
||||||
OpenMPIRBuilder::createTask(const LocationDescription &Loc,
|
OpenMPIRBuilder::createTask(const LocationDescription &Loc,
|
||||||
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
|
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
|
||||||
bool Tied) {
|
bool Tied, Value *Final) {
|
||||||
if (!updateToLocation(Loc))
|
if (!updateToLocation(Loc))
|
||||||
return InsertPointTy();
|
return InsertPointTy();
|
||||||
|
|
||||||
|
@ -1285,7 +1285,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
|
||||||
OI.EntryBB = TaskAllocaBB;
|
OI.EntryBB = TaskAllocaBB;
|
||||||
OI.OuterAllocaBB = AllocaIP.getBlock();
|
OI.OuterAllocaBB = AllocaIP.getBlock();
|
||||||
OI.ExitBB = TaskExitBB;
|
OI.ExitBB = TaskExitBB;
|
||||||
OI.PostOutlineCB = [this, &Loc, Tied](Function &OutlinedFn) {
|
OI.PostOutlineCB = [this, &Loc, Tied, Final](Function &OutlinedFn) {
|
||||||
// The input IR here looks like the following-
|
// The input IR here looks like the following-
|
||||||
// ```
|
// ```
|
||||||
// func @current_fn() {
|
// func @current_fn() {
|
||||||
|
@ -1330,10 +1330,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
|
||||||
Value *ThreadID = getOrCreateThreadID(Ident);
|
Value *ThreadID = getOrCreateThreadID(Ident);
|
||||||
|
|
||||||
// Argument - `flags`
|
// Argument - `flags`
|
||||||
// If task is tied, then (Flags & 1) == 1.
|
// Task is tied iff (Flags & 1) == 1.
|
||||||
// If task is untied, then (Flags & 1) == 0.
|
// Task is untied iff (Flags & 1) == 0.
|
||||||
|
// Task is final iff (Flags & 2) == 2.
|
||||||
|
// Task is not final iff (Flags & 2) == 0.
|
||||||
// TODO: Handle the other flags.
|
// TODO: Handle the other flags.
|
||||||
Value *Flags = Builder.getInt32(Tied);
|
Value *Flags = Builder.getInt32(Tied);
|
||||||
|
if (Final) {
|
||||||
|
Value *FinalFlag =
|
||||||
|
Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
|
||||||
|
Flags = Builder.CreateOr(FinalFlag, Flags);
|
||||||
|
}
|
||||||
|
|
||||||
// Argument - `sizeof_kmp_task_t` (TaskSize)
|
// Argument - `sizeof_kmp_task_t` (TaskSize)
|
||||||
// Tasksize refers to the size in bytes of kmp_task_t data structure
|
// Tasksize refers to the size in bytes of kmp_task_t data structure
|
||||||
|
|
|
@ -4832,4 +4832,57 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) {
|
||||||
EXPECT_FALSE(verifyModule(*M, &errs()));
|
EXPECT_FALSE(verifyModule(*M, &errs()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
|
||||||
|
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
|
||||||
|
OpenMPIRBuilder OMPBuilder(*M);
|
||||||
|
OMPBuilder.initialize();
|
||||||
|
F->setName("func");
|
||||||
|
IRBuilder<> Builder(BB);
|
||||||
|
auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
|
||||||
|
IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
|
||||||
|
BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
|
||||||
|
Builder.SetInsertPoint(BodyBB);
|
||||||
|
Value *Final = Builder.CreateICmp(
|
||||||
|
CmpInst::Predicate::ICMP_EQ, F->getArg(0),
|
||||||
|
ConstantInt::get(Type::getInt32Ty(M->getContext()), 0U));
|
||||||
|
OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
|
||||||
|
Builder.restoreIP(OMPBuilder.createTask(Loc, AllocaIP, BodyGenCB,
|
||||||
|
/*Tied=*/false, Final));
|
||||||
|
OMPBuilder.finalize();
|
||||||
|
Builder.CreateRetVoid();
|
||||||
|
|
||||||
|
// Check for the `Tied` argument
|
||||||
|
CallInst *TaskAllocCall = dyn_cast<CallInst>(
|
||||||
|
OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
|
||||||
|
->user_back());
|
||||||
|
ASSERT_NE(TaskAllocCall, nullptr);
|
||||||
|
BinaryOperator *OrInst =
|
||||||
|
dyn_cast<BinaryOperator>(TaskAllocCall->getArgOperand(2));
|
||||||
|
ASSERT_NE(OrInst, nullptr);
|
||||||
|
EXPECT_EQ(OrInst->getOpcode(), BinaryOperator::BinaryOps::Or);
|
||||||
|
|
||||||
|
// One of the arguments to `or` instruction is the tied flag, which is equal
|
||||||
|
// to zero.
|
||||||
|
EXPECT_TRUE(any_of(OrInst->operands(), [](Value *op) {
|
||||||
|
if (ConstantInt *TiedValue = dyn_cast<ConstantInt>(op))
|
||||||
|
return TiedValue->getSExtValue() == 0;
|
||||||
|
return false;
|
||||||
|
}));
|
||||||
|
|
||||||
|
// One of the arguments to `or` instruction is the final condition.
|
||||||
|
EXPECT_TRUE(any_of(OrInst->operands(), [Final](Value *op) {
|
||||||
|
if (SelectInst *Select = dyn_cast<SelectInst>(op)) {
|
||||||
|
ConstantInt *TrueValue = dyn_cast<ConstantInt>(Select->getTrueValue());
|
||||||
|
ConstantInt *FalseValue = dyn_cast<ConstantInt>(Select->getFalseValue());
|
||||||
|
if (!TrueValue || !FalseValue)
|
||||||
|
return false;
|
||||||
|
return Select->getCondition() == Final &&
|
||||||
|
TrueValue->getSExtValue() == 2 && FalseValue->getSExtValue() == 0;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}));
|
||||||
|
|
||||||
|
EXPECT_FALSE(verifyModule(*M, &errs()));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
Loading…
Reference in New Issue