[Clang][OpenMP] Enable floating-point operation for `atomic compare` series

D127041 introduced the support for `fmax` and `fmin` such that we can also reprent
`atomic compare` and `atomic compare capture` with `atomicrmw` instruction. This
patch simply lifts the limitation we set before.

Depend on D127041.

Reviewed By: ABataev

Differential Revision: https://reviews.llvm.org/D127042
This commit is contained in:
Shilei Tian 2022-07-06 13:05:00 -04:00
parent fbb51ac0ba
commit 83837a6198
5 changed files with 11497 additions and 49 deletions

View File

@ -5998,18 +5998,26 @@ static std::pair<bool, RValue> emitOMPAtomicRMW(CodeGenFunction &CGF, LValue X,
RMWOp = llvm::AtomicRMWInst::Xor;
break;
case BO_LT:
RMWOp = X.getType()->hasSignedIntegerRepresentation()
? (IsXLHSInRHSPart ? llvm::AtomicRMWInst::Min
: llvm::AtomicRMWInst::Max)
: (IsXLHSInRHSPart ? llvm::AtomicRMWInst::UMin
: llvm::AtomicRMWInst::UMax);
if (IsInteger)
RMWOp = X.getType()->hasSignedIntegerRepresentation()
? (IsXLHSInRHSPart ? llvm::AtomicRMWInst::Min
: llvm::AtomicRMWInst::Max)
: (IsXLHSInRHSPart ? llvm::AtomicRMWInst::UMin
: llvm::AtomicRMWInst::UMax);
else
RMWOp = IsXLHSInRHSPart ? llvm::AtomicRMWInst::FMin
: llvm::AtomicRMWInst::FMax;
break;
case BO_GT:
RMWOp = X.getType()->hasSignedIntegerRepresentation()
? (IsXLHSInRHSPart ? llvm::AtomicRMWInst::Max
: llvm::AtomicRMWInst::Min)
: (IsXLHSInRHSPart ? llvm::AtomicRMWInst::UMax
: llvm::AtomicRMWInst::UMin);
if (IsInteger)
RMWOp = X.getType()->hasSignedIntegerRepresentation()
? (IsXLHSInRHSPart ? llvm::AtomicRMWInst::Max
: llvm::AtomicRMWInst::Min)
: (IsXLHSInRHSPart ? llvm::AtomicRMWInst::UMax
: llvm::AtomicRMWInst::UMin);
else
RMWOp = IsXLHSInRHSPart ? llvm::AtomicRMWInst::FMax
: llvm::AtomicRMWInst::FMin;
break;
case BO_Assign:
RMWOp = llvm::AtomicRMWInst::Xchg;

View File

@ -11570,7 +11570,7 @@ protected:
bool checkType(ErrorInfoTy &ErrorInfo) const;
static bool CheckValue(const Expr *E, ErrorInfoTy &ErrorInfo,
bool ShouldBeLValue) {
bool ShouldBeLValue, bool ShouldBeInteger = false) {
if (ShouldBeLValue && !E->isLValue()) {
ErrorInfo.Error = ErrorTy::XNotLValue;
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = E->getExprLoc();
@ -11586,8 +11586,7 @@ protected:
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = E->getSourceRange();
return false;
}
if (!QTy->isIntegerType()) {
if (ShouldBeInteger && !QTy->isIntegerType()) {
ErrorInfo.Error = ErrorTy::NotInteger;
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = E->getExprLoc();
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = E->getSourceRange();
@ -11890,7 +11889,7 @@ bool OpenMPAtomicCompareCaptureChecker::checkType(ErrorInfoTy &ErrorInfo) {
if (V && !CheckValue(V, ErrorInfo, true))
return false;
if (R && !CheckValue(R, ErrorInfo, true))
if (R && !CheckValue(R, ErrorInfo, true, true))
return false;
return true;

File diff suppressed because it is too large Load Diff

View File

@ -482,23 +482,6 @@ void compare(void) {
else
d = e;
}
float fx = 0.0f;
float fd = 0.0f;
float fe = 0.0f;
// omp51-error@+5 {{the statement for 'atomic compare' must be a compound statement of form '{x = expr ordop x ? expr : x;}', '{x = x ordop expr? expr : x;}', '{x = x == e ? d : x;}', '{x = e == x ? d : x;}', or 'if(expr ordop x) {x = expr;}', 'if(x ordop expr) {x = expr;}', 'if(x == e) {x = d;}', 'if(e == x) {x = d;}' where 'x' is an lvalue expression with scalar type, 'expr', 'e', and 'd' are expressions with scalar type, and 'ordop' is one of '<' or '>'.}}
// omp51-note@+4 {{expect integer value}}
#pragma omp atomic compare
{
if (fx > fe)
fx = fe;
}
// omp51-error@+5 {{the statement for 'atomic compare' must be a compound statement of form '{x = expr ordop x ? expr : x;}', '{x = x ordop expr? expr : x;}', '{x = x == e ? d : x;}', '{x = e == x ? d : x;}', or 'if(expr ordop x) {x = expr;}', 'if(x ordop expr) {x = expr;}', 'if(x == e) {x = d;}', 'if(e == x) {x = d;}' where 'x' is an lvalue expression with scalar type, 'expr', 'e', and 'd' are expressions with scalar type, and 'ordop' is one of '<' or '>'.}}
// omp51-note@+4 {{expect integer value}}
#pragma omp atomic compare
{
if (fx == fe)
fx = fe;
}
}
void compare_capture(void) {
@ -507,6 +490,7 @@ void compare_capture(void) {
int e = 0;
int v = 0;
int r = 0;
float dr = 0.0;
// omp51-error@+3 {{the statement for 'atomic compare capture' must be a compound statement of form '{v = x; cond-up-stmt}', ''{cond-up-stmt v = x;}', '{if(x == e) {x = d;} else {v = x;}}', '{r = x == e; if(r) {x = d;}}', or '{r = x == e; if(r) {x = d;} else {v = x;}}', where 'cond-update-stmt' can have one of the following forms: 'if(expr ordop x) {x = expr;}', 'if(x ordop expr) {x = expr;}', 'if(x == e) {x = d;}', or 'if(e == x) {x = d;}' where 'x' is an lvalue expression with scalar type, 'expr', 'e', and 'd' are expressions with scalar type, and 'ordop' is one of '<' or '>'.}}
// omp51-note@+2 {{expected compound statement}}
#pragma omp atomic compare capture
@ -689,10 +673,9 @@ void compare_capture(void) {
#pragma omp atomic compare capture
{ v = x; bbar(); }
float fv;
// omp51-error@+3 {{the statement for 'atomic compare capture' must be a compound statement of form '{v = x; cond-up-stmt}', ''{cond-up-stmt v = x;}', '{if(x == e) {x = d;} else {v = x;}}', '{r = x == e; if(r) {x = d;}}', or '{r = x == e; if(r) {x = d;} else {v = x;}}', where 'cond-update-stmt' can have one of the following forms: 'if(expr ordop x) {x = expr;}', 'if(x ordop expr) {x = expr;}', 'if(x == e) {x = d;}', or 'if(e == x) {x = d;}' where 'x' is an lvalue expression with scalar type, 'expr', 'e', and 'd' are expressions with scalar type, and 'ordop' is one of '<' or '>'.}}
// omp51-note@+2 {{expect integer value}}
#pragma omp atomic compare capture
{ fv = x; if (x == e) { x = d; } }
{ dr = x == e; if (dr) { x = d; } }
}
#endif

View File

@ -4128,20 +4128,37 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
assert(X.Var->getType()->isPointerTy() &&
"OMP atomic expects a pointer to target memory");
assert((X.ElemTy->isIntegerTy() || X.ElemTy->isPointerTy()) &&
"OMP atomic compare expected a integer scalar type");
// compare capture
if (V.Var) {
assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
}
bool IsInteger = E->getType()->isIntegerTy();
if (Op == OMPAtomicCompareOp::EQ) {
AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
AtomicCmpXchgInst *Result =
Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
AtomicCmpXchgInst *Result = nullptr;
if (!IsInteger) {
unsigned Addrspace =
cast<PointerType>(X.Var->getType())->getAddressSpace();
IntegerType *IntCastTy =
IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
Value *XBCast =
Builder.CreateBitCast(X.Var, IntCastTy->getPointerTo(Addrspace));
Value *EBCast = Builder.CreateBitCast(E, IntCastTy);
Value *DBCast = Builder.CreateBitCast(D, IntCastTy);
Result = Builder.CreateAtomicCmpXchg(XBCast, EBCast, DBCast, MaybeAlign(),
AO, Failure);
} else {
Result =
Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
}
if (V.Var) {
Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
if (!IsInteger)
OldValue = Builder.CreateBitCast(OldValue, X.ElemTy);
assert(OldValue->getType() == V.ElemTy &&
"OldValue and V must be of same type");
if (IsPostfixUpdate) {
@ -4215,19 +4232,29 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
// x = x <= expr ? x : expr;
AtomicRMWInst::BinOp NewOp;
if (IsXBinopExpr) {
if (X.IsSigned)
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
: AtomicRMWInst::Max;
else
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
: AtomicRMWInst::UMax;
if (IsInteger) {
if (X.IsSigned)
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
: AtomicRMWInst::Max;
else
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
: AtomicRMWInst::UMax;
} else {
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
: AtomicRMWInst::FMax;
}
} else {
if (X.IsSigned)
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
: AtomicRMWInst::Min;
else
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
: AtomicRMWInst::UMin;
if (IsInteger) {
if (X.IsSigned)
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
: AtomicRMWInst::Min;
else
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
: AtomicRMWInst::UMin;
} else {
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
: AtomicRMWInst::FMin;
}
}
AtomicRMWInst *OldValue =
@ -4245,12 +4272,18 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
case AtomicRMWInst::UMax:
Pred = CmpInst::ICMP_UGT;
break;
case AtomicRMWInst::FMax:
Pred = CmpInst::FCMP_OGT;
break;
case AtomicRMWInst::Min:
Pred = CmpInst::ICMP_SLT;
break;
case AtomicRMWInst::UMin:
Pred = CmpInst::ICMP_ULT;
break;
case AtomicRMWInst::FMin:
Pred = CmpInst::FCMP_OLT;
break;
default:
llvm_unreachable("unexpected comparison op");
}