[MLIR][OpenMP] Added omp.atomic.update

This patch supports the atomic construct (update) following section 2.17.7 of OpenMP 5.0 standard. Also added tests and verifier for the same.

Reviewed By: kiranchandramohan, peixin

Differential Revision: https://reviews.llvm.org/D112982
This commit is contained in:
Shraiysh Vaishay 2021-12-09 11:05:55 +05:30
parent 120d44d1a0
commit d82c1f4e4b
6 changed files with 239 additions and 20 deletions

View File

@ -1166,9 +1166,9 @@ private:
/// \param UpdateOp Code generator for complex expressions that cannot be
/// expressed through atomicrmw instruction.
/// \param VolatileX true if \a X volatile?
/// \param IsXLHSInRHSPart true if \a X is Left H.S. in Right H.S. part of
/// the update expression, false otherwise.
/// (e.g. true for X = X BinOp Expr)
/// \param IsXBinopExpr true if \a X is Left H.S. in Right H.S. part of the
/// update expression, false otherwise.
/// (e.g. true for X = X BinOp Expr)
///
/// \returns A pair of the old value of X before the update, and the value
/// used for the update.
@ -1177,7 +1177,7 @@ private:
AtomicRMWInst::BinOp RMWOp,
AtomicUpdateCallbackTy &UpdateOp,
bool VolatileX,
bool IsXLHSInRHSPart);
bool IsXBinopExpr);
/// Emit the binary op. described by \p RMWOp, using \p Src1 and \p Src2 .
///
@ -1235,9 +1235,9 @@ public:
/// atomic will be generated.
/// \param UpdateOp Code generator for complex expressions that cannot be
/// expressed through atomicrmw instruction.
/// \param IsXLHSInRHSPart true if \a X is Left H.S. in Right H.S. part of
/// the update expression, false otherwise.
/// (e.g. true for X = X BinOp Expr)
/// \param IsXBinopExpr true if \a X is Left H.S. in Right H.S. part of the
/// update expression, false otherwise.
/// (e.g. true for X = X BinOp Expr)
///
/// \return Insertion point after generated atomic update IR.
InsertPointTy createAtomicUpdate(const LocationDescription &Loc,
@ -1245,7 +1245,7 @@ public:
Value *Expr, AtomicOrdering AO,
AtomicRMWInst::BinOp RMWOp,
AtomicUpdateCallbackTy &UpdateOp,
bool IsXLHSInRHSPart);
bool IsXBinopExpr);
/// Emit atomic update for constructs: --- Only Scalar data types
/// V = X; X = X BinOp Expr ,
@ -1269,9 +1269,9 @@ public:
/// expressed through atomicrmw instruction.
/// \param UpdateExpr true if X is an in place update of the form
/// X = X BinOp Expr or X = Expr BinOp X
/// \param IsXLHSInRHSPart true if X is Left H.S. in Right H.S. part of the
/// update expression, false otherwise.
/// (e.g. true for X = X BinOp Expr)
/// \param IsXBinopExpr true if X is Left H.S. in Right H.S. part of the
/// update expression, false otherwise.
/// (e.g. true for X = X BinOp Expr)
/// \param IsPostfixUpdate true if original value of 'x' must be stored in
/// 'v', not an updated one.
///
@ -1281,7 +1281,7 @@ public:
AtomicOpValue &X, AtomicOpValue &V, Value *Expr,
AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
AtomicUpdateCallbackTy &UpdateOp, bool UpdateExpr,
bool IsPostfixUpdate, bool IsXLHSInRHSPart);
bool IsPostfixUpdate, bool IsXBinopExpr);
/// Create the control flow structure of a canonical OpenMP loop.
///

View File

@ -3080,7 +3080,7 @@ OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
const LocationDescription &Loc, Instruction *AllocIP, AtomicOpValue &X,
Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
AtomicUpdateCallbackTy &UpdateOp, bool IsXLHSInRHSPart) {
AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
if (!updateToLocation(Loc))
return Loc.IP;
@ -3098,7 +3098,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
});
emitAtomicUpdate(AllocIP, X.Var, Expr, AO, RMWOp, UpdateOp, X.IsVolatile,
IsXLHSInRHSPart);
IsXBinopExpr);
checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update);
return Builder.saveIP();
}
@ -3135,13 +3135,13 @@ std::pair<Value *, Value *>
OpenMPIRBuilder::emitAtomicUpdate(Instruction *AllocIP, Value *X, Value *Expr,
AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
AtomicUpdateCallbackTy &UpdateOp,
bool VolatileX, bool IsXLHSInRHSPart) {
bool VolatileX, bool IsXBinopExpr) {
Type *XElemTy = X->getType()->getPointerElementType();
bool DoCmpExch =
((RMWOp == AtomicRMWInst::BAD_BINOP) || (RMWOp == AtomicRMWInst::FAdd)) ||
(RMWOp == AtomicRMWInst::FSub) ||
(RMWOp == AtomicRMWInst::Sub && !IsXLHSInRHSPart);
(RMWOp == AtomicRMWInst::Sub && !IsXBinopExpr);
std::pair<Value *, Value *> Res;
if (XElemTy->isIntegerTy() && !DoCmpExch) {
@ -3233,7 +3233,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
const LocationDescription &Loc, Instruction *AllocIP, AtomicOpValue &X,
AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
bool UpdateExpr, bool IsPostfixUpdate, bool IsXLHSInRHSPart) {
bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) {
if (!updateToLocation(Loc))
return Loc.IP;
@ -3252,9 +3252,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
// If UpdateExpr is 'x' updated with some `expr` not based on 'x',
// 'x' is simply atomically rewritten with 'expr'.
AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
std::pair<Value *, Value *> Result =
emitAtomicUpdate(AllocIP, X.Var, Expr, AO, AtomicOp, UpdateOp,
X.IsVolatile, IsXLHSInRHSPart);
std::pair<Value *, Value *> Result = emitAtomicUpdate(
AllocIP, X.Var, Expr, AO, AtomicOp, UpdateOp, X.IsVolatile, IsXBinopExpr);
Value *CapturedVal = (IsPostfixUpdate ? Result.first : Result.second);
Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);

View File

@ -595,6 +595,45 @@ def AtomicWriteOp : OpenMP_Op<"atomic.write"> {
let verifier = [{ return verifyAtomicWriteOp(*this); }];
}
// TODO: autogenerate from OMP.td in future if possible.
def ATOMIC_BINOP_KIND_ADD : I64EnumAttrCase<"ADD", 0>;
def ATOMIC_BINOP_KIND_MUL : I64EnumAttrCase<"MUL", 1>;
def ATOMIC_BINOP_KIND_SUB : I64EnumAttrCase<"SUB", 2>;
def ATOMIC_BINOP_KIND_DIV : I64EnumAttrCase<"DIV", 3>;
def ATOMIC_BINOP_KIND_AND : I64EnumAttrCase<"AND", 4>;
def ATOMIC_BINOP_KIND_OR : I64EnumAttrCase<"OR", 5>;
def ATOMIC_BINOP_KIND_XOR : I64EnumAttrCase<"XOR", 6>;
def ATOMIC_BINOP_KIND_SHIFT_RIGHT : I64EnumAttrCase<"SHIFTR", 7>;
def ATOMIC_BINOP_KIND_SHIFT_LEFT : I64EnumAttrCase<"SHIFTL", 8>;
def ATOMIC_BINOP_KIND_MAX : I64EnumAttrCase<"MAX", 9>;
def ATOMIC_BINOP_KIND_MIN : I64EnumAttrCase<"MIN", 10>;
def ATOMIC_BINOP_KIND_EQV : I64EnumAttrCase<"EQV", 11>;
def ATOMIC_BINOP_KIND_NEQV : I64EnumAttrCase<"NEQV", 12>;
def AtomicBinOpKindAttr : I64EnumAttr<
"AtomicBinOpKind", "BinOp for Atomic Updates",
[ATOMIC_BINOP_KIND_ADD, ATOMIC_BINOP_KIND_MUL, ATOMIC_BINOP_KIND_SUB,
ATOMIC_BINOP_KIND_DIV, ATOMIC_BINOP_KIND_AND, ATOMIC_BINOP_KIND_OR,
ATOMIC_BINOP_KIND_XOR, ATOMIC_BINOP_KIND_SHIFT_RIGHT,
ATOMIC_BINOP_KIND_SHIFT_LEFT, ATOMIC_BINOP_KIND_MAX,
ATOMIC_BINOP_KIND_MIN, ATOMIC_BINOP_KIND_EQV, ATOMIC_BINOP_KIND_NEQV]> {
let cppNamespace = "::mlir::omp";
let stringToSymbolFnName = "AtomicBinOpKindToEnum";
let symbolToStringFnName = "AtomicBinOpKindToString";
}
def AtomicUpdateOp : OpenMP_Op<"atomic.update"> {
let arguments = (ins OpenMP_PointerLikeType:$x,
AnyType:$expr,
UnitAttr:$isXBinopExpr,
AtomicBinOpKindAttr:$binop,
DefaultValuedAttr<I64Attr, "0">:$hint,
OptionalAttr<MemoryOrderKind>:$memory_order);
let parser = [{ return parseAtomicUpdateOp(parser, result); }];
let printer = [{ return printAtomicUpdateOp(p, *this); }];
let verifier = [{ return verifyAtomicUpdateOp(*this); }];
}
//===----------------------------------------------------------------------===//
// 2.19.5.7 declare reduction Directive
//===----------------------------------------------------------------------===//

View File

@ -1423,5 +1423,84 @@ static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) {
return verifySynchronizationHint(op, op.hint());
}
//===----------------------------------------------------------------------===//
// AtomicUpdateOp
//===----------------------------------------------------------------------===//
/// Parser for AtomicUpdateOp
///
/// operation ::= `omp.atomic.update` atomic-clause-list region
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
SmallVector<int> segments;
OpAsmParser::OperandType x, y, z;
Type xType, exprType;
StringRef binOp;
// x = y `op` z : xtype, exprtype
if (parser.parseOperand(x) || parser.parseEqual() || parser.parseOperand(y) ||
parser.parseKeyword(&binOp) || parser.parseOperand(z) ||
parseClauses(parser, result, clauses, segments) || parser.parseColon() ||
parser.parseType(xType) || parser.parseComma() ||
parser.parseType(exprType) ||
parser.resolveOperand(x, xType, result.operands)) {
return failure();
}
auto binOpEnum = AtomicBinOpKindToEnum(binOp.upper());
if (!binOpEnum)
return parser.emitError(parser.getNameLoc())
<< "invalid atomic bin op in atomic update\n";
auto attr =
parser.getBuilder().getI64IntegerAttr((int64_t)binOpEnum.getValue());
result.addAttribute("binop", attr);
OpAsmParser::OperandType expr;
if (x.name == y.name && x.number == y.number) {
expr = z;
result.addAttribute("isXBinopExpr", parser.getBuilder().getUnitAttr());
} else if (x.name == z.name && x.number == z.number) {
expr = y;
} else {
return parser.emitError(parser.getNameLoc())
<< "atomic update variable " << x.name
<< " not found in the RHS of the assignment statement in an"
" atomic.update operation";
}
return parser.resolveOperand(expr, exprType, result.operands);
}
/// Printer for AtomicUpdateOp
static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) {
p << " " << op.x() << " = ";
Value y, z;
if (op.isXBinopExpr()) {
y = op.x();
z = op.expr();
} else {
y = op.expr();
z = op.x();
}
p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z
<< " ";
if (op.memory_order())
p << "memory_order(" << op.memory_order() << ") ";
if (op.hintAttr())
printSynchronizationHint(p, op, op.hintAttr());
p << ": " << op.x().getType() << ", " << op.expr().getType();
}
/// Verifier for AtomicUpdateOp
static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) {
if (op.memory_order()) {
StringRef memoryOrder = op.memory_order().getValue();
if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire"))
return op.emitError(
"memory-order must not be acq_rel or acquire for atomic updates");
}
return success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"

View File

@ -601,6 +601,47 @@ func @omp_atomic_write6(%addr : memref<i32>, %val : i32) {
// -----
func @omp_atomic_update1(%x: memref<i32>, %expr: i32, %foo: memref<i32>) {
// expected-error @below {{atomic update variable %x not found in the RHS of the assignment statement in an atomic.update operation}}
omp.atomic.update %x = %foo add %expr : memref<i32>, i32
return
}
// -----
func @omp_atomic_update2(%x: memref<i32>, %expr: i32) {
// expected-error @below {{invalid atomic bin op in atomic update}}
omp.atomic.update %x = %x invalid %expr : memref<i32>, i32
return
}
// -----
func @omp_atomic_update3(%x: memref<i32>, %expr: i32) {
// expected-error @below {{memory-order must not be acq_rel or acquire for atomic updates}}
omp.atomic.update %x = %x add %expr memory_order(acq_rel) : memref<i32>, i32
return
}
// -----
func @omp_atomic_update4(%x: memref<i32>, %expr: i32) {
// expected-error @below {{memory-order must not be acq_rel or acquire for atomic updates}}
omp.atomic.update %x = %x add %expr memory_order(acquire) : memref<i32>, i32
return
}
// -----
// expected-note @below {{prior use here}}
func @omp_atomic_update5(%x: memref<i32>, %expr: i32) {
// expected-error @below {{use of value '%x' expects different type than prior uses: 'i32' vs 'memref<i32>'}}
omp.atomic.update %x = %x add %expr : i32, memref<i32>
return
}
// -----
func @omp_sections(%data_var1 : memref<i32>, %data_var2 : memref<i32>, %data_var3 : memref<i32>) -> () {
// expected-error @below {{operand used in both private and firstprivate clauses}}
omp.sections private(%data_var1 : memref<i32>) firstprivate(%data_var1 : memref<i32>) {

View File

@ -524,6 +524,67 @@ func @omp_atomic_write(%addr : memref<i32>, %val : i32) {
return
}
// CHECK-LABEL: omp_atomic_update
// CHECK-SAME: (%[[X:.*]]: memref<i32>, %[[EXPR:.*]]: i32, %[[XBOOL:.*]]: memref<i1>, %[[EXPRBOOL:.*]]: i1)
func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>, %exprBool : i1) {
// CHECK: omp.atomic.update %[[X]] = %[[X]] add %[[EXPR]] : memref<i32>, i32
omp.atomic.update %x = %x add %expr : memref<i32>, i32
// CHECK: omp.atomic.update %[[X]] = %[[X]] sub %[[EXPR]] : memref<i32>, i32
omp.atomic.update %x = %x sub %expr : memref<i32>, i32
// CHECK: omp.atomic.update %[[X]] = %[[X]] mul %[[EXPR]] : memref<i32>, i32
omp.atomic.update %x = %x mul %expr : memref<i32>, i32
// CHECK: omp.atomic.update %[[X]] = %[[X]] div %[[EXPR]] : memref<i32>, i32
omp.atomic.update %x = %x div %expr : memref<i32>, i32
// CHECK: omp.atomic.update %[[XBOOL]] = %[[XBOOL]] and %[[EXPRBOOL]] : memref<i1>, i1
omp.atomic.update %xBool = %xBool and %exprBool : memref<i1>, i1
// CHECK: omp.atomic.update %[[XBOOL]] = %[[XBOOL]] or %[[EXPRBOOL]] : memref<i1>, i1
omp.atomic.update %xBool = %xBool or %exprBool : memref<i1>, i1
// CHECK: omp.atomic.update %[[XBOOL]] = %[[XBOOL]] xor %[[EXPRBOOL]] : memref<i1>, i1
omp.atomic.update %xBool = %xBool xor %exprBool : memref<i1>, i1
// CHECK: omp.atomic.update %[[X]] = %[[X]] shiftr %[[EXPR]] : memref<i32>, i32
omp.atomic.update %x = %x shiftr %expr : memref<i32>, i32
// CHECK: omp.atomic.update %[[X]] = %[[X]] shiftl %[[EXPR]] : memref<i32>, i32
omp.atomic.update %x = %x shiftl %expr : memref<i32>, i32
// CHECK: omp.atomic.update %[[X]] = %[[X]] max %[[EXPR]] : memref<i32>, i32
omp.atomic.update %x = %x max %expr : memref<i32>, i32
// CHECK: omp.atomic.update %[[X]] = %[[X]] min %[[EXPR]] : memref<i32>, i32
omp.atomic.update %x = %x min %expr : memref<i32>, i32
// CHECK: omp.atomic.update %[[XBOOL]] = %[[XBOOL]] eqv %[[EXPRBOOL]] : memref<i1>, i1
omp.atomic.update %xBool = %xBool eqv %exprBool : memref<i1>, i1
// CHECK: omp.atomic.update %[[XBOOL]] = %[[XBOOL]] neqv %[[EXPRBOOL]] : memref<i1>, i1
omp.atomic.update %xBool = %xBool neqv %exprBool : memref<i1>, i1
// CHECK: omp.atomic.update %[[X]] = %[[EXPR]] add %[[X]] : memref<i32>, i32
omp.atomic.update %x = %expr add %x : memref<i32>, i32
// CHECK: omp.atomic.update %[[X]] = %[[EXPR]] sub %[[X]] : memref<i32>, i32
omp.atomic.update %x = %expr sub %x : memref<i32>, i32
// CHECK: omp.atomic.update %[[X]] = %[[EXPR]] mul %[[X]] : memref<i32>, i32
omp.atomic.update %x = %expr mul %x : memref<i32>, i32
// CHECK: omp.atomic.update %[[X]] = %[[EXPR]] div %[[X]] : memref<i32>, i32
omp.atomic.update %x = %expr div %x : memref<i32>, i32
// CHECK: omp.atomic.update %[[XBOOL]] = %[[EXPRBOOL]] and %[[XBOOL]] : memref<i1>, i1
omp.atomic.update %xBool = %exprBool and %xBool : memref<i1>, i1
// CHECK: omp.atomic.update %[[XBOOL]] = %[[EXPRBOOL]] or %[[XBOOL]] : memref<i1>, i1
omp.atomic.update %xBool = %exprBool or %xBool : memref<i1>, i1
// CHECK: omp.atomic.update %[[XBOOL]] = %[[EXPRBOOL]] xor %[[XBOOL]] : memref<i1>, i1
omp.atomic.update %xBool = %exprBool xor %xBool : memref<i1>, i1
// CHECK: omp.atomic.update %[[X]] = %[[EXPR]] shiftr %[[X]] : memref<i32>, i32
omp.atomic.update %x = %expr shiftr %x : memref<i32>, i32
// CHECK: omp.atomic.update %[[X]] = %[[EXPR]] shiftl %[[X]] : memref<i32>, i32
omp.atomic.update %x = %expr shiftl %x : memref<i32>, i32
// CHECK: omp.atomic.update %[[X]] = %[[EXPR]] max %[[X]] : memref<i32>, i32
omp.atomic.update %x = %expr max %x : memref<i32>, i32
// CHECK: omp.atomic.update %[[X]] = %[[EXPR]] min %[[X]] : memref<i32>, i32
omp.atomic.update %x = %expr min %x : memref<i32>, i32
// CHECK: omp.atomic.update %[[XBOOL]] = %[[EXPRBOOL]] eqv %[[XBOOL]] : memref<i1>, i1
omp.atomic.update %xBool = %exprBool eqv %xBool : memref<i1>, i1
// CHECK: omp.atomic.update %[[XBOOL]] = %[[EXPRBOOL]] neqv %[[XBOOL]] : memref<i1>, i1
omp.atomic.update %xBool = %exprBool neqv %xBool : memref<i1>, i1
// CHECK: omp.atomic.update %[[X]] = %[[EXPR]] add %[[X]] memory_order(seq_cst) hint(speculative) : memref<i32>, i32
omp.atomic.update %x = %expr add %x hint(speculative) memory_order(seq_cst) : memref<i32>, i32
return
}
// CHECK-LABEL: omp_sectionsop
func @omp_sectionsop(%data_var1 : memref<i32>, %data_var2 : memref<i32>,
%data_var3 : memref<i32>, %redn_var : !llvm.ptr<f32>) {