[OMPIRBuilder] Add support for atomic compare

This patch adds the support for `atomic compare` in `OMPIRBuilder`.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D118547
This commit is contained in:
Shilei Tian 2022-02-02 15:38:59 -05:00
parent 1a3137dc84
commit 1a547a94c1
4 changed files with 157 additions and 1 deletions

View File

@ -116,6 +116,9 @@ enum class AddressSpace : unsigned {
/// \note This needs to be kept in sync with interop.h enum kmp_interop_type_t.:
enum class OMPInteropType { Unknown, Target, TargetSync };
/// Atomic compare operations. Currently OpenMP only supports ==, >, and <.
enum class OMPAtomicCompareOp : unsigned { EQ, MIN, MAX };
} // end namespace omp
} // end namespace llvm

View File

@ -1198,7 +1198,7 @@ private:
const function_ref<Value *(Value *XOld, IRBuilder<> &IRB)>;
private:
enum AtomicKind { Read, Write, Update, Capture };
enum AtomicKind { Read, Write, Update, Capture, Compare };
/// Determine whether to emit flush or not
///
@ -1344,6 +1344,39 @@ public:
AtomicUpdateCallbackTy &UpdateOp, bool UpdateExpr,
bool IsPostfixUpdate, bool IsXBinopExpr);
/// Emit atomic compare for constructs: --- Only scalar data types
/// cond-update-atomic:
/// x = x ordop expr ? expr : x;
/// x = expr ordop x ? expr : x;
/// x = x == e ? d : x;
/// x = e == x ? d : x; (this one is not in the spec)
/// cond-update-stmt:
/// if (x ordop expr) { x = expr; }
/// if (expr ordop x) { x = expr; }
/// if (x == e) { x = d; }
/// if (e == x) { x = d; } (this one is not in the spec)
///
/// \param Loc The insert and source location description.
/// \param X The target atomic pointer to be updated.
/// \param E The expected value ('e') for forms that use an
/// equality comparison or an expression ('expr') for
/// forms that use 'ordop' (logically an atomic maximum or
/// minimum).
/// \param D The desired value for forms that use an equality
/// comparison. If forms that use 'ordop', it should be
/// \p nullptr.
/// \param AO Atomic ordering of the generated atomic instructions.
/// \param OP Atomic compare operation. It can only be ==, <, or >.
/// \param IsXBinopExpr True if the conditional statement is in the form where
/// x is on LHS. It only matters for < or >.
///
/// \return Insertion point after generated atomic capture IR.
InsertPointTy createAtomicCompare(const LocationDescription &Loc,
AtomicOpValue &X, Value *E, Value *D,
AtomicOrdering AO,
omp::OMPAtomicCompareOp Op,
bool IsXBinopExpr);
/// Create the control flow structure of a canonical OpenMP loop.
///
/// The emitted loop will be disconnected, i.e. no edge to the loop's

View File

@ -3171,6 +3171,7 @@ bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
}
break;
case Write:
case Compare:
case Update:
if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
AO == AtomicOrdering::SequentiallyConsistent) {
@ -3472,6 +3473,68 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
return Builder.saveIP();
}
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
const LocationDescription &Loc, AtomicOpValue &X, Value *E, Value *D,
AtomicOrdering AO, OMPAtomicCompareOp Op, bool IsXBinopExpr) {
if (!updateToLocation(Loc))
return Loc.IP;
assert(X.Var->getType()->isPointerTy() &&
"OMP atomic expects a pointer to target memory");
assert((X.ElemTy->isFloatingPointTy() || X.ElemTy->isIntegerTy() ||
X.ElemTy->isPointerTy()) &&
"OMP atomic compare expected a scalar type");
if (Op == OMPAtomicCompareOp::EQ) {
unsigned Addrspace = cast<PointerType>(X.Var->getType())->getAddressSpace();
IntegerType *IntCastTy =
IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
Value *XAddr =
X.ElemTy->isIntegerTy()
? X.Var
: Builder.CreateBitCast(X.Var, IntCastTy->getPointerTo(Addrspace));
AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
// We don't need the result for now.
(void)Builder.CreateAtomicCmpXchg(XAddr, E, D, MaybeAlign(), AO, Failure);
} else {
assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
"Op should be either max or min at this point");
assert(X.ElemTy->isIntegerTy() &&
"max and min operators only support integer type");
// Reverse the ordop as the OpenMP forms are different from LLVM forms.
// Let's take max as example.
// OpenMP form:
// x = x > expr ? expr : x;
// LLVM form:
// *ptr = *ptr > val ? *ptr : val;
// We need to transform to LLVM form.
// x = x <= expr ? x : expr;
AtomicRMWInst::BinOp NewOp;
if (IsXBinopExpr) {
if (X.IsSigned)
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
: AtomicRMWInst::Max;
else
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
: AtomicRMWInst::UMax;
} else {
if (X.IsSigned)
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
: AtomicRMWInst::Min;
else
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
: AtomicRMWInst::UMin;
}
// We dont' need the result for now.
(void)Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
}
checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);
return Builder.saveIP();
}
GlobalVariable *
OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
std::string VarName) {

View File

@ -3031,6 +3031,63 @@ TEST_F(OpenMPIRBuilderTest, OMPAtomicCapture) {
EXPECT_FALSE(verifyModule(*M, &errs()));
}
TEST_F(OpenMPIRBuilderTest, OMPAtomicCompare) {
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
F->setName("func");
IRBuilder<> Builder(BB);
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
LLVMContext &Ctx = M->getContext();
IntegerType *Int32 = Type::getInt32Ty(Ctx);
AllocaInst *XVal = Builder.CreateAlloca(Int32);
XVal->setName("x");
StoreInst *Init =
Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 0U), XVal);
OpenMPIRBuilder::AtomicOpValue XSigned = {XVal, Int32, true, false};
OpenMPIRBuilder::AtomicOpValue XUnsigned = {XVal, Int32, false, false};
AtomicOrdering AO = AtomicOrdering::Monotonic;
ConstantInt *Expr = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
ConstantInt *D = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
OMPAtomicCompareOp OpMax = OMPAtomicCompareOp::MAX;
OMPAtomicCompareOp OpEQ = OMPAtomicCompareOp::EQ;
Builder.restoreIP(OMPBuilder.createAtomicCompare(Builder, XSigned, Expr,
nullptr, AO, OpMax, true));
Builder.restoreIP(OMPBuilder.createAtomicCompare(Builder, XUnsigned, Expr,
nullptr, AO, OpMax, false));
Builder.restoreIP(OMPBuilder.createAtomicCompare(Builder, XSigned, Expr, D,
AO, OpEQ, true));
BasicBlock *EntryBB = BB;
EXPECT_EQ(EntryBB->getParent()->size(), 1U);
EXPECT_EQ(EntryBB->size(), 5U);
AtomicRMWInst *ARWM1 = dyn_cast<AtomicRMWInst>(Init->getNextNode());
EXPECT_NE(ARWM1, nullptr);
EXPECT_EQ(ARWM1->getPointerOperand(), XVal);
EXPECT_EQ(ARWM1->getValOperand(), Expr);
EXPECT_EQ(ARWM1->getOperation(), AtomicRMWInst::Min);
AtomicRMWInst *ARWM2 = dyn_cast<AtomicRMWInst>(ARWM1->getNextNode());
EXPECT_NE(ARWM2, nullptr);
EXPECT_EQ(ARWM2->getPointerOperand(), XVal);
EXPECT_EQ(ARWM2->getValOperand(), Expr);
EXPECT_EQ(ARWM2->getOperation(), AtomicRMWInst::UMax);
AtomicCmpXchgInst *AXCHG = dyn_cast<AtomicCmpXchgInst>(ARWM2->getNextNode());
EXPECT_NE(AXCHG, nullptr);
EXPECT_EQ(AXCHG->getPointerOperand(), XVal);
EXPECT_EQ(AXCHG->getCompareOperand(), Expr);
EXPECT_EQ(AXCHG->getNewValOperand(), D);
Builder.CreateRetVoid();
OMPBuilder.finalize();
EXPECT_FALSE(verifyModule(*M, &errs()));
}
/// Returns the single instruction of InstTy type in BB that uses the value V.
/// If there is more than one such instruction, returns null.
template <typename InstTy>