[mlir][OpenMP] OpenMP Synchronization Hints stored as IntegerAttr

`hint-expression` is an IntegerAttr, because it can be a combination of multiple values from the enum `omp_sync_hint_t` (Section 2.17.12 of OpenMP 5.0)

Reviewed By: ftynse, kiranchandramohan

Differential Revision: https://reviews.llvm.org/D111360
This commit is contained in:
Shraiysh Vaishay 2021-10-12 10:47:30 +00:00 committed by Kiran Chandramohan
parent 269d0e223a
commit 7a79c6afea
5 changed files with 168 additions and 29 deletions

View File

@ -369,21 +369,6 @@ def CriticalDeclareOp : OpenMP_Op<"critical.declare", [Symbol]> {
}
// TODO: Autogenerate this from OMP.td in llvm/include/Frontend
def omp_sync_hint_none: I32EnumAttrCase<"none", 0>;
def omp_sync_hint_uncontended: I32EnumAttrCase<"uncontended", 1>;
def omp_sync_hint_contended: I32EnumAttrCase<"contended", 2>;
def omp_sync_hint_nonspeculative: I32EnumAttrCase<"nonspeculative", 3>;
def omp_sync_hint_speculative: I32EnumAttrCase<"speculative", 4>;
def SyncHintKind: I32EnumAttr<"SyncHintKind", "OpenMP Sync Hint Kind",
[omp_sync_hint_none, omp_sync_hint_uncontended, omp_sync_hint_contended,
omp_sync_hint_nonspeculative, omp_sync_hint_speculative]> {
let cppNamespace = "::mlir::omp";
let stringToSymbolFnName = "ConvertToEnum";
let symbolToStringFnName = "ConvertToString";
}
def CriticalOp : OpenMP_Op<"critical"> {
let summary = "critical construct";
let description = [{
@ -392,12 +377,12 @@ def CriticalOp : OpenMP_Op<"critical"> {
}];
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$name,
OptionalAttr<SyncHintKind>:$hint);
DefaultValuedAttr<I64Attr, "0">:$hint);
let regions = (region AnyRegion:$region);
let assemblyFormat = [{
(`(` $name^ `)`)? (`hint` `(` $hint^ `)`)? $region attr-dict
(`(` $name^ `)`)? custom<SynchronizationHint>($hint) $region attr-dict
}];
let verifier = "return ::verifyCriticalOp(*this);";

View File

@ -959,11 +959,109 @@ static LogicalResult verifyWsLoopOp(WsLoopOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// Parser, printer and verifier for Synchronization Hint (2.17.12)
//===----------------------------------------------------------------------===//
/// Parses a Synchronization Hint clause. The value of hint is an integer
/// which is a combination of different hints from `omp_sync_hint_t`.
///
/// hint-clause = `hint` `(` hint-value `)`
static ParseResult parseSynchronizationHint(OpAsmParser &parser,
IntegerAttr &hintAttr) {
if (failed(parser.parseOptionalKeyword("hint"))) {
hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
return success();
}
if (failed(parser.parseLParen()))
return failure();
StringRef hintKeyword;
int64_t hint = 0;
do {
if (failed(parser.parseKeyword(&hintKeyword)))
return failure();
if (hintKeyword == "uncontended")
hint |= 1;
else if (hintKeyword == "contended")
hint |= 2;
else if (hintKeyword == "nonspeculative")
hint |= 4;
else if (hintKeyword == "speculative")
hint |= 8;
else
return parser.emitError(parser.getCurrentLocation())
<< hintKeyword << " is not a valid hint";
} while (succeeded(parser.parseOptionalComma()));
if (failed(parser.parseRParen()))
return failure();
hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
return success();
}
/// Prints a Synchronization Hint clause
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
IntegerAttr hintAttr) {
int64_t hint = hintAttr.getInt();
if (hint == 0)
return;
// Helper function to get n-th bit from the right end of `value`
auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
bool uncontended = bitn(hint, 0);
bool contended = bitn(hint, 1);
bool nonspeculative = bitn(hint, 2);
bool speculative = bitn(hint, 3);
SmallVector<StringRef> hints;
if (uncontended)
hints.push_back("uncontended");
if (contended)
hints.push_back("contended");
if (nonspeculative)
hints.push_back("nonspeculative");
if (speculative)
hints.push_back("speculative");
p << "hint(";
llvm::interleaveComma(hints, p);
p << ")";
}
/// Verifies a synchronization hint clause
static LogicalResult verifySynchronizationHint(Operation *op, int32_t hint) {
// Helper function to get n-th bit from the right end of `value`
auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
bool uncontended = bitn(hint, 0);
bool contended = bitn(hint, 1);
bool nonspeculative = bitn(hint, 2);
bool speculative = bitn(hint, 3);
if (uncontended && contended)
return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
"omp_sync_hint_contended cannot be combined";
if (nonspeculative && speculative)
return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
"omp_sync_hint_speculative cannot be combined.";
return success();
}
//===----------------------------------------------------------------------===//
// Verifier for critical construct (2.17.1)
//===----------------------------------------------------------------------===//
static LogicalResult verifyCriticalOp(CriticalOp op) {
if (!op.name().hasValue() && op.hint().hasValue() &&
(op.hint().getValue() != SyncHintKind::none))
if (failed(verifySynchronizationHint(op, op.hint()))) {
return failure();
}
if (!op.name().hasValue() && (op.hint() != 0))
return op.emitOpError() << "must specify a name unless the effect is as if "
"hint(none) is specified";
"no hint is specified";
if (op.nameAttr()) {
auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();

View File

@ -300,14 +300,8 @@ convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::LocationDescription ompLoc(
builder.saveIP(), builder.getCurrentDebugLocation());
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
llvm::Constant *hint = nullptr;
if (criticalOp.hint().hasValue()) {
hint =
llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
static_cast<int>(criticalOp.hint().getValue()));
} else {
hint = llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
}
llvm::Constant *hint = llvm::ConstantInt::get(
llvm::Type::getInt32Ty(llvmContext), static_cast<int>(criticalOp.hint()));
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint));
return success();

View File

@ -297,7 +297,7 @@ func @foo(%lb : index, %ub : index, %step : index, %mem : memref<1xf32>) {
// -----
func @omp_critical1() -> () {
// expected-error @below {{must specify a name unless the effect is as if hint(none) is specified}}
// expected-error @below {{must specify a name unless the effect is as if no hint is specified}}
omp.critical hint(nonspeculative) {
omp.terminator
}
@ -313,3 +313,35 @@ func @omp_critical2() -> () {
}
return
}
// -----
omp.critical.declare @mutex
func @omp_critical() -> () {
// expected-error @below {{the hints omp_sync_hint_uncontended and omp_sync_hint_contended cannot be combined}}
omp.critical(@mutex) hint(uncontended, contended) {
omp.terminator
}
return
}
// -----
omp.critical.declare @mutex
func @omp_critical() -> () {
// expected-error @below {{the hints omp_sync_hint_nonspeculative and omp_sync_hint_speculative cannot be combined}}
omp.critical(@mutex) hint(nonspeculative, speculative) {
omp.terminator
}
return
}
// -----
omp.critical.declare @mutex
func @omp_critica() -> () {
// expected-error @below {{invalid_hint is not a valid hint}}
omp.critical(@mutex) hint(invalid_hint) {
omp.terminator
}
}

View File

@ -375,12 +375,42 @@ omp.critical.declare @mutex
// CHECK-LABEL: omp_critical
func @omp_critical() -> () {
// CHECK: omp.critical
omp.critical {
omp.terminator
}
// CHECK: omp.critical(@{{.*}}) hint(uncontended)
omp.critical(@mutex) hint(uncontended) {
omp.terminator
}
// CHECK: omp.critical(@{{.*}}) hint(contended)
omp.critical(@mutex) hint(contended) {
omp.terminator
}
// CHECK: omp.critical(@{{.*}}) hint(nonspeculative)
omp.critical(@mutex) hint(nonspeculative) {
omp.terminator
}
// CHECK: omp.critical(@{{.*}}) hint(uncontended, nonspeculative)
omp.critical(@mutex) hint(uncontended, nonspeculative) {
omp.terminator
}
// CHECK: omp.critical(@{{.*}}) hint(contended, nonspeculative)
omp.critical(@mutex) hint(nonspeculative, contended) {
omp.terminator
}
// CHECK: omp.critical(@{{.*}}) hint(speculative)
omp.critical(@mutex) hint(speculative) {
omp.terminator
}
// CHECK: omp.critical(@{{.*}}) hint(uncontended, speculative)
omp.critical(@mutex) hint(uncontended, speculative) {
omp.terminator
}
// CHECK: omp.critical(@{{.*}}) hint(contended, speculative)
omp.critical(@mutex) hint(speculative, contended) {
omp.terminator
}
return
}