forked from OSchip/llvm-project
[MLIR,OpenMP] Lowering of parallel operation: proc_bind clause 2/n
This patch adds the translation of the proc_bind clause in a parallel operation. The values that can be specified for the proc_bind clause are specified in the OMP.td tablegen file in the llvm/Frontend/OpenMP directory. From this single source of truth enumeration for proc_bind is generated in llvm and mlir (used in specification of the parallel Operation in the OpenMP dialect). A function to return the enum value from the string representation is also generated. A new header file (DirectiveEmitter.h) containing definitions of classes directive, clause, clauseval etc is created so that it can be used in mlir as well. Reviewers: clementval, jdoerfert, DavidTruby Differential Revision: https://reviews.llvm.org/D84347
This commit is contained in:
parent
6b3dc96e59
commit
e6c5e6efd0
|
@ -51,6 +51,21 @@ class DirectiveLanguage {
|
|||
string flangClauseBaseClass = "";
|
||||
}
|
||||
|
||||
// Information about values accepted by enum-like clauses
|
||||
class ClauseVal<string n, int v, bit uv> {
|
||||
// Name of the clause value.
|
||||
string name = n;
|
||||
|
||||
// Integer value of the clause.
|
||||
int value = v;
|
||||
|
||||
// Can user specify this value?
|
||||
bit isUserValue = uv;
|
||||
|
||||
// Set clause value used by default when unknown.
|
||||
bit isDefault = 0;
|
||||
}
|
||||
|
||||
// Information about a specific clause.
|
||||
class Clause<string c> {
|
||||
// Name of the clause.
|
||||
|
@ -75,11 +90,17 @@ class Clause<string c> {
|
|||
// If set to 1, value is optional. Not optional by default.
|
||||
bit isValueOptional = 0;
|
||||
|
||||
// Name of enum when there is a list of allowed clause values.
|
||||
string enumClauseValue = "";
|
||||
|
||||
// List of allowed clause values
|
||||
list<ClauseVal> allowedClauseValues = [];
|
||||
|
||||
// Is clause implicit? If clause is set as implicit, the default kind will
|
||||
// be return in get<LanguageName>ClauseKind instead of their own kind.
|
||||
bit isImplicit = 0;
|
||||
|
||||
// Set directive used by default when unknown. Function returning the kind
|
||||
// Set clause used by default when unknown. Function returning the kind
|
||||
// of enumeration will use this clause as the default.
|
||||
bit isDefault = 0;
|
||||
}
|
||||
|
|
|
@ -99,9 +99,22 @@ def OMPC_CopyPrivate : Clause<"copyprivate"> {
|
|||
let clangClass = "OMPCopyprivateClause";
|
||||
let flangClassValue = "OmpObjectList";
|
||||
}
|
||||
def OMP_PROC_BIND_master : ClauseVal<"master",2,1> {}
|
||||
def OMP_PROC_BIND_close : ClauseVal<"close",3,1> {}
|
||||
def OMP_PROC_BIND_spread : ClauseVal<"spread",4,1> {}
|
||||
def OMP_PROC_BIND_default : ClauseVal<"default",5,0> {}
|
||||
def OMP_PROC_BIND_unknown : ClauseVal<"unknown",6,0> { let isDefault = 1; }
|
||||
def OMPC_ProcBind : Clause<"proc_bind"> {
|
||||
let clangClass = "OMPProcBindClause";
|
||||
let flangClass = "OmpProcBindClause";
|
||||
let enumClauseValue = "ProcBindKind";
|
||||
let allowedClauseValues = [
|
||||
OMP_PROC_BIND_master,
|
||||
OMP_PROC_BIND_close,
|
||||
OMP_PROC_BIND_spread,
|
||||
OMP_PROC_BIND_default,
|
||||
OMP_PROC_BIND_unknown
|
||||
];
|
||||
}
|
||||
def OMPC_Schedule : Clause<"schedule"> {
|
||||
let clangClass = "OMPScheduleClause";
|
||||
|
|
|
@ -68,16 +68,6 @@ enum class DefaultKind {
|
|||
constexpr auto Enum = omp::DefaultKind::Enum;
|
||||
#include "llvm/Frontend/OpenMP/OMPKinds.def"
|
||||
|
||||
/// IDs for the different proc bind kinds.
|
||||
enum class ProcBindKind {
|
||||
#define OMP_PROC_BIND_KIND(Enum, Str, Value) Enum = Value,
|
||||
#include "llvm/Frontend/OpenMP/OMPKinds.def"
|
||||
};
|
||||
|
||||
#define OMP_PROC_BIND_KIND(Enum, ...) \
|
||||
constexpr auto Enum = omp::ProcBindKind::Enum;
|
||||
#include "llvm/Frontend/OpenMP/OMPKinds.def"
|
||||
|
||||
/// IDs for all omp runtime library ident_t flag encodings (see
|
||||
/// their defintion in openmp/runtime/src/kmp.h).
|
||||
enum class IdentFlag {
|
||||
|
|
|
@ -0,0 +1,188 @@
|
|||
#ifndef LLVM_TABLEGEN_DIRECTIVEEMITTER_H
|
||||
#define LLVM_TABLEGEN_DIRECTIVEEMITTER_H
|
||||
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
namespace llvm {
|
||||
|
||||
// Wrapper class that contains DirectiveLanguage's information defined in
|
||||
// DirectiveBase.td and provides helper methods for accessing it.
|
||||
class DirectiveLanguage {
|
||||
public:
|
||||
explicit DirectiveLanguage(const llvm::Record *Def) : Def(Def) {}
|
||||
|
||||
StringRef getName() const { return Def->getValueAsString("name"); }
|
||||
|
||||
StringRef getCppNamespace() const {
|
||||
return Def->getValueAsString("cppNamespace");
|
||||
}
|
||||
|
||||
StringRef getDirectivePrefix() const {
|
||||
return Def->getValueAsString("directivePrefix");
|
||||
}
|
||||
|
||||
StringRef getClausePrefix() const {
|
||||
return Def->getValueAsString("clausePrefix");
|
||||
}
|
||||
|
||||
StringRef getIncludeHeader() const {
|
||||
return Def->getValueAsString("includeHeader");
|
||||
}
|
||||
|
||||
StringRef getClauseEnumSetClass() const {
|
||||
return Def->getValueAsString("clauseEnumSetClass");
|
||||
}
|
||||
|
||||
StringRef getFlangClauseBaseClass() const {
|
||||
return Def->getValueAsString("flangClauseBaseClass");
|
||||
}
|
||||
|
||||
bool hasMakeEnumAvailableInNamespace() const {
|
||||
return Def->getValueAsBit("makeEnumAvailableInNamespace");
|
||||
}
|
||||
|
||||
bool hasEnableBitmaskEnumInNamespace() const {
|
||||
return Def->getValueAsBit("enableBitmaskEnumInNamespace");
|
||||
}
|
||||
|
||||
private:
|
||||
const llvm::Record *Def;
|
||||
};
|
||||
|
||||
// Base record class used for Directive and Clause class defined in
|
||||
// DirectiveBase.td.
|
||||
class BaseRecord {
|
||||
public:
|
||||
explicit BaseRecord(const llvm::Record *Def) : Def(Def) {}
|
||||
|
||||
StringRef getName() const { return Def->getValueAsString("name"); }
|
||||
|
||||
StringRef getAlternativeName() const {
|
||||
return Def->getValueAsString("alternativeName");
|
||||
}
|
||||
|
||||
// Returns the name of the directive formatted for output. Whitespace are
|
||||
// replaced with underscores.
|
||||
std::string getFormattedName() {
|
||||
StringRef Name = Def->getValueAsString("name");
|
||||
std::string N = Name.str();
|
||||
std::replace(N.begin(), N.end(), ' ', '_');
|
||||
return N;
|
||||
}
|
||||
|
||||
bool isDefault() const { return Def->getValueAsBit("isDefault"); }
|
||||
|
||||
protected:
|
||||
const llvm::Record *Def;
|
||||
};
|
||||
|
||||
// Wrapper class that contains a Directive's information defined in
|
||||
// DirectiveBase.td and provides helper methods for accessing it.
|
||||
class Directive : public BaseRecord {
|
||||
public:
|
||||
explicit Directive(const llvm::Record *Def) : BaseRecord(Def) {}
|
||||
|
||||
std::vector<Record *> getAllowedClauses() const {
|
||||
return Def->getValueAsListOfDefs("allowedClauses");
|
||||
}
|
||||
|
||||
std::vector<Record *> getAllowedOnceClauses() const {
|
||||
return Def->getValueAsListOfDefs("allowedOnceClauses");
|
||||
}
|
||||
|
||||
std::vector<Record *> getAllowedExclusiveClauses() const {
|
||||
return Def->getValueAsListOfDefs("allowedExclusiveClauses");
|
||||
}
|
||||
|
||||
std::vector<Record *> getRequiredClauses() const {
|
||||
return Def->getValueAsListOfDefs("requiredClauses");
|
||||
}
|
||||
};
|
||||
|
||||
// Wrapper class that contains Clause's information defined in DirectiveBase.td
|
||||
// and provides helper methods for accessing it.
|
||||
class Clause : public BaseRecord {
|
||||
public:
|
||||
explicit Clause(const llvm::Record *Def) : BaseRecord(Def) {}
|
||||
|
||||
// Optional field.
|
||||
StringRef getClangClass() const {
|
||||
return Def->getValueAsString("clangClass");
|
||||
}
|
||||
|
||||
// Optional field.
|
||||
StringRef getFlangClass() const {
|
||||
return Def->getValueAsString("flangClass");
|
||||
}
|
||||
|
||||
// Optional field.
|
||||
StringRef getFlangClassValue() const {
|
||||
return Def->getValueAsString("flangClassValue");
|
||||
}
|
||||
|
||||
// Get the formatted name for Flang parser class. The generic formatted class
|
||||
// name is constructed from the name were the first letter of each word is
|
||||
// captitalized and the underscores are removed.
|
||||
// ex: async -> Async
|
||||
// num_threads -> NumThreads
|
||||
std::string getFormattedParserClassName() {
|
||||
StringRef Name = Def->getValueAsString("name");
|
||||
std::string N = Name.str();
|
||||
bool Cap = true;
|
||||
std::transform(N.begin(), N.end(), N.begin(), [&Cap](unsigned char C) {
|
||||
if (Cap == true) {
|
||||
C = llvm::toUpper(C);
|
||||
Cap = false;
|
||||
} else if (C == '_') {
|
||||
Cap = true;
|
||||
}
|
||||
return C;
|
||||
});
|
||||
N.erase(std::remove(N.begin(), N.end(), '_'), N.end());
|
||||
return N;
|
||||
}
|
||||
|
||||
// Optional field.
|
||||
StringRef getEnumName() const {
|
||||
return Def->getValueAsString("enumClauseValue");
|
||||
}
|
||||
|
||||
std::vector<Record *> getClauseVals() const {
|
||||
return Def->getValueAsListOfDefs("allowedClauseValues");
|
||||
}
|
||||
|
||||
bool isValueOptional() const { return Def->getValueAsBit("isValueOptional"); }
|
||||
|
||||
bool isImplict() const { return Def->getValueAsBit("isImplicit"); }
|
||||
};
|
||||
|
||||
// Wrapper class that contains VersionedClause's information defined in
|
||||
// DirectiveBase.td and provides helper methods for accessing it.
|
||||
class VersionedClause {
|
||||
public:
|
||||
explicit VersionedClause(const llvm::Record *Def) : Def(Def) {}
|
||||
|
||||
// Return the specific clause record wrapped in the Clause class.
|
||||
Clause getClause() const { return Clause{Def->getValueAsDef("clause")}; }
|
||||
|
||||
int64_t getMinVersion() const { return Def->getValueAsInt("minVersion"); }
|
||||
|
||||
int64_t getMaxVersion() const { return Def->getValueAsInt("maxVersion"); }
|
||||
|
||||
private:
|
||||
const llvm::Record *Def;
|
||||
};
|
||||
|
||||
class ClauseVal : public BaseRecord {
|
||||
public:
|
||||
explicit ClauseVal(const llvm::Record *Def) : BaseRecord(Def) {}
|
||||
|
||||
int getValue() const { return Def->getValueAsInt("value"); }
|
||||
|
||||
bool isUserVisible() const { return Def->getValueAsBit("isUserValue"); }
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
#endif
|
|
@ -15,9 +15,20 @@ def TestDirectiveLanguage : DirectiveLanguage {
|
|||
let flangClauseBaseClass = "TdlClause";
|
||||
}
|
||||
|
||||
def TDLCV_vala : ClauseVal<"vala",1,1> {}
|
||||
def TDLCV_valb : ClauseVal<"valb",2,1> {}
|
||||
def TDLCV_valc : ClauseVal<"valc",3,0> { let isDefault = 1; }
|
||||
|
||||
def TDLC_ClauseA : Clause<"clausea"> {
|
||||
let flangClass = "TdlClauseA";
|
||||
let enumClauseValue = "AKind";
|
||||
let allowedClauseValues = [
|
||||
TDLCV_vala,
|
||||
TDLCV_valb,
|
||||
TDLCV_valc
|
||||
];
|
||||
}
|
||||
|
||||
def TDLC_ClauseB : Clause<"clauseb"> {
|
||||
let flangClassValue = "IntExpr";
|
||||
let isValueOptional = 1;
|
||||
|
@ -61,6 +72,16 @@ def TDL_DirA : Directive<"dira"> {
|
|||
// CHECK-NEXT: constexpr auto TDLC_clausea = llvm::tdl::Clause::TDLC_clausea;
|
||||
// CHECK-NEXT: constexpr auto TDLC_clauseb = llvm::tdl::Clause::TDLC_clauseb;
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: enum class AKind {
|
||||
// CHECK-NEXT: TDLCV_vala=1,
|
||||
// CHECK-NEXT: TDLCV_valb=2,
|
||||
// CHECK-NEXT: TDLCV_valc=3,
|
||||
// CHECK-NEXT: };
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: constexpr auto TDLCV_vala = llvm::tdl::AKind::TDLCV_vala;
|
||||
// CHECK-NEXT: constexpr auto TDLCV_valb = llvm::tdl::AKind::TDLCV_valb;
|
||||
// CHECK-NEXT: constexpr auto TDLCV_valc = llvm::tdl::AKind::TDLCV_valc;
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: // Enumeration helper functions
|
||||
// CHECK-NEXT: Directive getTdlDirectiveKind(llvm::StringRef Str);
|
||||
// CHECK-EMPTY:
|
||||
|
@ -73,6 +94,8 @@ def TDL_DirA : Directive<"dira"> {
|
|||
// CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version.
|
||||
// CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: AKind getAKind(StringRef);
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } // namespace tdl
|
||||
// CHECK-NEXT: } // namespace llvm
|
||||
// CHECK-NEXT: #endif // LLVM_Tdl_INC
|
||||
|
@ -116,6 +139,14 @@ def TDL_DirA : Directive<"dira"> {
|
|||
// IMPL-NEXT: llvm_unreachable("Invalid Tdl Clause kind");
|
||||
// IMPL-NEXT: }
|
||||
// IMPL-EMPTY:
|
||||
// IMPL-NEXT: AKind llvm::tdl::getAKind(llvm::StringRef Str) {
|
||||
// IMPL-NEXT: return llvm::StringSwitch<AKind>(Str)
|
||||
// IMPL-NEXT: .Case("vala",TDLCV_vala)
|
||||
// IMPL-NEXT: .Case("valb",TDLCV_valb)
|
||||
// IMPL-NEXT: .Case("valc",TDLCV_valc)
|
||||
// IMPL-NEXT: .Default(TDLCV_valc);
|
||||
// IMPL-NEXT: }
|
||||
// IMPL-EMPTY:
|
||||
// IMPL-NEXT: bool llvm::tdl::isAllowedClauseForDirective(Directive D, Clause C, unsigned Version) {
|
||||
// IMPL-NEXT: assert(unsigned(D) <= llvm::tdl::Directive_enumSize);
|
||||
// IMPL-NEXT: assert(unsigned(C) <= llvm::tdl::Clause_enumSize);
|
||||
|
|
|
@ -11,15 +11,14 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "llvm/TableGen/DirectiveEmitter.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
namespace {
|
||||
|
@ -41,165 +40,6 @@ private:
|
|||
|
||||
namespace llvm {
|
||||
|
||||
// Wrapper class that contains DirectiveLanguage's information defined in
|
||||
// DirectiveBase.td and provides helper methods for accessing it.
|
||||
class DirectiveLanguage {
|
||||
public:
|
||||
explicit DirectiveLanguage(const llvm::Record *Def) : Def(Def) {}
|
||||
|
||||
StringRef getName() const { return Def->getValueAsString("name"); }
|
||||
|
||||
StringRef getCppNamespace() const {
|
||||
return Def->getValueAsString("cppNamespace");
|
||||
}
|
||||
|
||||
StringRef getDirectivePrefix() const {
|
||||
return Def->getValueAsString("directivePrefix");
|
||||
}
|
||||
|
||||
StringRef getClausePrefix() const {
|
||||
return Def->getValueAsString("clausePrefix");
|
||||
}
|
||||
|
||||
StringRef getIncludeHeader() const {
|
||||
return Def->getValueAsString("includeHeader");
|
||||
}
|
||||
|
||||
StringRef getClauseEnumSetClass() const {
|
||||
return Def->getValueAsString("clauseEnumSetClass");
|
||||
}
|
||||
|
||||
StringRef getFlangClauseBaseClass() const {
|
||||
return Def->getValueAsString("flangClauseBaseClass");
|
||||
}
|
||||
|
||||
bool hasMakeEnumAvailableInNamespace() const {
|
||||
return Def->getValueAsBit("makeEnumAvailableInNamespace");
|
||||
}
|
||||
|
||||
bool hasEnableBitmaskEnumInNamespace() const {
|
||||
return Def->getValueAsBit("enableBitmaskEnumInNamespace");
|
||||
}
|
||||
|
||||
private:
|
||||
const llvm::Record *Def;
|
||||
};
|
||||
|
||||
// Base record class used for Directive and Clause class defined in
|
||||
// DirectiveBase.td.
|
||||
class BaseRecord {
|
||||
public:
|
||||
explicit BaseRecord(const llvm::Record *Def) : Def(Def) {}
|
||||
|
||||
StringRef getName() const { return Def->getValueAsString("name"); }
|
||||
|
||||
StringRef getAlternativeName() const {
|
||||
return Def->getValueAsString("alternativeName");
|
||||
}
|
||||
|
||||
// Returns the name of the directive formatted for output. Whitespace are
|
||||
// replaced with underscores.
|
||||
std::string getFormattedName() {
|
||||
StringRef Name = Def->getValueAsString("name");
|
||||
std::string N = Name.str();
|
||||
std::replace(N.begin(), N.end(), ' ', '_');
|
||||
return N;
|
||||
}
|
||||
|
||||
bool isDefault() const { return Def->getValueAsBit("isDefault"); }
|
||||
|
||||
protected:
|
||||
const llvm::Record *Def;
|
||||
};
|
||||
|
||||
// Wrapper class that contains a Directive's information defined in
|
||||
// DirectiveBase.td and provides helper methods for accessing it.
|
||||
class Directive : public BaseRecord {
|
||||
public:
|
||||
explicit Directive(const llvm::Record *Def) : BaseRecord(Def) {}
|
||||
|
||||
std::vector<Record *> getAllowedClauses() const {
|
||||
return Def->getValueAsListOfDefs("allowedClauses");
|
||||
}
|
||||
|
||||
std::vector<Record *> getAllowedOnceClauses() const {
|
||||
return Def->getValueAsListOfDefs("allowedOnceClauses");
|
||||
}
|
||||
|
||||
std::vector<Record *> getAllowedExclusiveClauses() const {
|
||||
return Def->getValueAsListOfDefs("allowedExclusiveClauses");
|
||||
}
|
||||
|
||||
std::vector<Record *> getRequiredClauses() const {
|
||||
return Def->getValueAsListOfDefs("requiredClauses");
|
||||
}
|
||||
};
|
||||
|
||||
// Wrapper class that contains Clause's information defined in DirectiveBase.td
|
||||
// and provides helper methods for accessing it.
|
||||
class Clause : public BaseRecord {
|
||||
public:
|
||||
explicit Clause(const llvm::Record *Def) : BaseRecord(Def) {}
|
||||
|
||||
// Optional field.
|
||||
StringRef getClangClass() const {
|
||||
return Def->getValueAsString("clangClass");
|
||||
}
|
||||
|
||||
// Optional field.
|
||||
StringRef getFlangClass() const {
|
||||
return Def->getValueAsString("flangClass");
|
||||
}
|
||||
|
||||
// Optional field.
|
||||
StringRef getFlangClassValue() const {
|
||||
return Def->getValueAsString("flangClassValue");
|
||||
}
|
||||
|
||||
// Get the formatted name for Flang parser class. The generic formatted class
|
||||
// name is constructed from the name were the first letter of each word is
|
||||
// captitalized and the underscores are removed.
|
||||
// ex: async -> Async
|
||||
// num_threads -> NumThreads
|
||||
std::string getFormattedParserClassName() {
|
||||
StringRef Name = Def->getValueAsString("name");
|
||||
std::string N = Name.str();
|
||||
bool Cap = true;
|
||||
std::transform(N.begin(), N.end(), N.begin(), [&Cap](unsigned char C) {
|
||||
if (Cap == true) {
|
||||
C = llvm::toUpper(C);
|
||||
Cap = false;
|
||||
} else if (C == '_') {
|
||||
Cap = true;
|
||||
}
|
||||
return C;
|
||||
});
|
||||
N.erase(std::remove(N.begin(), N.end(), '_'), N.end());
|
||||
return N;
|
||||
}
|
||||
|
||||
bool isValueOptional() const { return Def->getValueAsBit("isValueOptional"); }
|
||||
|
||||
bool isImplict() const { return Def->getValueAsBit("isImplicit"); }
|
||||
};
|
||||
|
||||
// Wrapper class that contains VersionedClause's information defined in
|
||||
// DirectiveBase.td and provides helper methods for accessing it.
|
||||
class VersionedClause {
|
||||
public:
|
||||
explicit VersionedClause(const llvm::Record *Def) : Def(Def) {}
|
||||
|
||||
// Return the specific clause record wrapped in the Clause class.
|
||||
Clause getClause() const { return Clause{Def->getValueAsDef("clause")}; }
|
||||
|
||||
int64_t getMinVersion() const { return Def->getValueAsInt("minVersion"); }
|
||||
|
||||
int64_t getMaxVersion() const { return Def->getValueAsInt("maxVersion"); }
|
||||
|
||||
private:
|
||||
const llvm::Record *Def;
|
||||
};
|
||||
|
||||
// Generate enum class
|
||||
void GenerateEnumClass(const std::vector<Record *> &Records, raw_ostream &OS,
|
||||
StringRef Enum, StringRef Prefix,
|
||||
|
@ -231,6 +71,46 @@ void GenerateEnumClass(const std::vector<Record *> &Records, raw_ostream &OS,
|
|||
}
|
||||
}
|
||||
|
||||
// Generate enums for values that clauses can take.
|
||||
// Also generate function declarations for get<Enum>Name(StringRef Str).
|
||||
void GenerateEnumClauseVal(const std::vector<Record *> &Records,
|
||||
raw_ostream &OS, DirectiveLanguage &DirLang,
|
||||
std::string &EnumHelperFuncs) {
|
||||
for (const auto &R : Records) {
|
||||
Clause C{R};
|
||||
const auto &ClauseVals = C.getClauseVals();
|
||||
if (ClauseVals.size() <= 0)
|
||||
continue;
|
||||
|
||||
const auto &EnumName = C.getEnumName();
|
||||
if (EnumName.size() == 0) {
|
||||
PrintError("enumClauseValue field not set in Clause" +
|
||||
C.getFormattedName() + ".");
|
||||
return;
|
||||
}
|
||||
|
||||
OS << "\n";
|
||||
OS << "enum class " << EnumName << " {\n";
|
||||
for (const auto &CV : ClauseVals) {
|
||||
ClauseVal CVal{CV};
|
||||
OS << " " << CV->getName() << "=" << CVal.getValue() << ",\n";
|
||||
}
|
||||
OS << "};\n";
|
||||
|
||||
if (DirLang.hasMakeEnumAvailableInNamespace()) {
|
||||
OS << "\n";
|
||||
for (const auto &CV : ClauseVals) {
|
||||
OS << "constexpr auto " << CV->getName() << " = "
|
||||
<< "llvm::" << DirLang.getCppNamespace() << "::" << EnumName
|
||||
<< "::" << CV->getName() << ";\n";
|
||||
}
|
||||
EnumHelperFuncs += (llvm::Twine(EnumName) + llvm::Twine(" get") +
|
||||
llvm::Twine(EnumName) + llvm::Twine("(StringRef);\n"))
|
||||
.str();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate the declaration section for the enumeration in the directive
|
||||
// language
|
||||
void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
|
||||
|
@ -273,6 +153,10 @@ void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
|
|||
const auto &Clauses = Records.getAllDerivedDefinitions("Clause");
|
||||
GenerateEnumClass(Clauses, OS, "Clause", DirLang.getClausePrefix(), DirLang);
|
||||
|
||||
// Emit ClauseVal enumeration
|
||||
std::string EnumHelperFuncs;
|
||||
GenerateEnumClauseVal(Clauses, OS, DirLang, EnumHelperFuncs);
|
||||
|
||||
// Generic function signatures
|
||||
OS << "\n";
|
||||
OS << "// Enumeration helper functions\n";
|
||||
|
@ -292,6 +176,10 @@ void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
|
|||
OS << "bool isAllowedClauseForDirective(Directive D, "
|
||||
<< "Clause C, unsigned Version);\n";
|
||||
OS << "\n";
|
||||
if (EnumHelperFuncs.length() > 0) {
|
||||
OS << EnumHelperFuncs;
|
||||
OS << "\n";
|
||||
}
|
||||
|
||||
// Closing namespaces
|
||||
for (auto Ns : llvm::reverse(Namespaces))
|
||||
|
@ -336,7 +224,7 @@ void GenerateGetKind(const std::vector<Record *> &Records, raw_ostream &OS,
|
|||
});
|
||||
|
||||
if (DefaultIt == Records.end()) {
|
||||
PrintError("A least one " + Enum + " must be defined as default.");
|
||||
PrintError("At least one " + Enum + " must be defined as default.");
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -361,6 +249,49 @@ void GenerateGetKind(const std::vector<Record *> &Records, raw_ostream &OS,
|
|||
OS << "}\n";
|
||||
}
|
||||
|
||||
// Generate function implementation for get<ClauseVal>Kind(StringRef Str)
|
||||
void GenerateGetKindClauseVal(const std::vector<Record *> &Records,
|
||||
raw_ostream &OS, StringRef Namespace) {
|
||||
|
||||
for (const auto &R : Records) {
|
||||
Clause C{R};
|
||||
const auto &ClauseVals = C.getClauseVals();
|
||||
if (ClauseVals.size() <= 0)
|
||||
continue;
|
||||
|
||||
auto DefaultIt =
|
||||
std::find_if(ClauseVals.begin(), ClauseVals.end(), [](Record *CV) {
|
||||
return CV->getValueAsBit("isDefault") == true;
|
||||
});
|
||||
|
||||
if (DefaultIt == ClauseVals.end()) {
|
||||
PrintError("At least one val in Clause " + C.getFormattedName() +
|
||||
" must be defined as default.");
|
||||
return;
|
||||
}
|
||||
const auto DefaultName = (*DefaultIt)->getName();
|
||||
|
||||
const auto &EnumName = C.getEnumName();
|
||||
if (EnumName.size() == 0) {
|
||||
PrintError("enumClauseValue field not set in Clause" +
|
||||
C.getFormattedName() + ".");
|
||||
return;
|
||||
}
|
||||
|
||||
OS << "\n";
|
||||
OS << EnumName << " llvm::" << Namespace << "::get" << EnumName
|
||||
<< "(llvm::StringRef Str) {\n";
|
||||
OS << " return llvm::StringSwitch<" << EnumName << ">(Str)\n";
|
||||
for (const auto &CV : ClauseVals) {
|
||||
ClauseVal CVal{CV};
|
||||
OS << " .Case(\"" << CVal.getFormattedName() << "\"," << CV->getName()
|
||||
<< ")\n";
|
||||
}
|
||||
OS << " .Default(" << DefaultName << ");\n";
|
||||
OS << "}\n";
|
||||
}
|
||||
}
|
||||
|
||||
void GenerateCaseForVersionedClauses(const std::vector<Record *> &Clauses,
|
||||
raw_ostream &OS, StringRef DirectiveName,
|
||||
DirectiveLanguage &DirLang,
|
||||
|
@ -672,6 +603,9 @@ void EmitDirectivesImpl(RecordKeeper &Records, raw_ostream &OS) {
|
|||
// getClauseName(Clause Kind)
|
||||
GenerateGetName(Clauses, OS, "Clause", DirLang, DirLang.getClausePrefix());
|
||||
|
||||
// get<ClauseVal>Kind(StringRef Str)
|
||||
GenerateGetKindClauseVal(Clauses, OS, DirLang.getCppNamespace());
|
||||
|
||||
// isAllowedClauseForDirective(Directive D, Clause C, unsigned Version)
|
||||
GenerateIsAllowedClause(Directives, OS, DirLang);
|
||||
}
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
set(LLVM_TARGET_DEFINITIONS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/OpenMP/OMP.td)
|
||||
mlir_tablegen(OmpCommon.td --gen-directive-decl)
|
||||
add_public_tablegen_target(omp_common_td)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS OpenMPOps.td)
|
||||
mlir_tablegen(OpenMPOpsDialect.h.inc -gen-dialect-decls -dialect=omp)
|
||||
mlir_tablegen(OpenMPOps.h.inc -gen-op-decls)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#define OPENMP_OPS
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/OpenMP/OmpCommon.td"
|
||||
|
||||
def OpenMP_Dialect : Dialect {
|
||||
let name = "omp";
|
||||
|
@ -42,18 +43,6 @@ def ClauseDefault : StrEnumAttr<
|
|||
let cppNamespace = "::mlir::omp";
|
||||
}
|
||||
|
||||
// Possible values for the proc_bind clause
|
||||
def ClauseProcMaster : StrEnumAttrCase<"master">;
|
||||
def ClauseProcClose : StrEnumAttrCase<"close">;
|
||||
def ClauseProcSpread : StrEnumAttrCase<"spread">;
|
||||
|
||||
def ClauseProcBind : StrEnumAttr<
|
||||
"ClauseProcBind",
|
||||
"procbind clause",
|
||||
[ClauseProcMaster, ClauseProcClose, ClauseProcSpread]> {
|
||||
let cppNamespace = "::mlir::omp";
|
||||
}
|
||||
|
||||
def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
|
||||
let summary = "parallel construct";
|
||||
let description = [{
|
||||
|
@ -87,7 +76,7 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
|
|||
Variadic<AnyType>:$firstprivate_vars,
|
||||
Variadic<AnyType>:$shared_vars,
|
||||
Variadic<AnyType>:$copyin_vars,
|
||||
OptionalAttr<ClauseProcBind>:$proc_bind_val);
|
||||
OptionalAttr<ProcBindKind>:$proc_bind_val);
|
||||
|
||||
let regions = (region AnyRegion:$region);
|
||||
|
||||
|
|
|
@ -408,32 +408,31 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
|
|||
blockMapping[&bb] = llvmBB;
|
||||
}
|
||||
|
||||
// Then, convert blocks one by one in topological order to ensure
|
||||
// defs are converted before uses.
|
||||
llvm::SetVector<Block *> blocks = topologicalSort(region);
|
||||
for (auto indexedBB : llvm::enumerate(blocks)) {
|
||||
Block *bb = indexedBB.value();
|
||||
llvm::BasicBlock *curLLVMBB = blockMapping[bb];
|
||||
if (bb->isEntryBlock())
|
||||
codeGenIPBBTI->setSuccessor(0, curLLVMBB);
|
||||
// Then, convert blocks one by one in topological order to ensure
|
||||
// defs are converted before uses.
|
||||
llvm::SetVector<Block *> blocks = topologicalSort(region);
|
||||
for (auto indexedBB : llvm::enumerate(blocks)) {
|
||||
Block *bb = indexedBB.value();
|
||||
llvm::BasicBlock *curLLVMBB = blockMapping[bb];
|
||||
if (bb->isEntryBlock())
|
||||
codeGenIPBBTI->setSuccessor(0, curLLVMBB);
|
||||
|
||||
// TODO: Error not returned up the hierarchy
|
||||
if (failed(
|
||||
convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
|
||||
return;
|
||||
// TODO: Error not returned up the hierarchy
|
||||
if (failed(convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
|
||||
return;
|
||||
|
||||
// If this block has the terminator then add a jump to
|
||||
// continuation bb
|
||||
for (auto &op : *bb) {
|
||||
if (isa<omp::TerminatorOp>(op)) {
|
||||
builder.SetInsertPoint(curLLVMBB);
|
||||
builder.CreateBr(&continuationIP);
|
||||
}
|
||||
// If this block has the terminator then add a jump to
|
||||
// continuation bb
|
||||
for (auto &op : *bb) {
|
||||
if (isa<omp::TerminatorOp>(op)) {
|
||||
builder.SetInsertPoint(curLLVMBB);
|
||||
builder.CreateBr(&continuationIP);
|
||||
}
|
||||
}
|
||||
// Finally, after all blocks have been traversed and values mapped,
|
||||
// connect the PHI nodes to the results of preceding blocks.
|
||||
connectPHINodes(region, valueMapping, blockMapping);
|
||||
}
|
||||
// Finally, after all blocks have been traversed and values mapped,
|
||||
// connect the PHI nodes to the results of preceding blocks.
|
||||
connectPHINodes(region, valueMapping, blockMapping);
|
||||
};
|
||||
|
||||
// TODO: Perform appropriate actions according to the data-sharing
|
||||
|
@ -451,23 +450,24 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
|
|||
// called for variables which have destructors/finalizers.
|
||||
auto finiCB = [&](InsertPointTy codeGenIP) {};
|
||||
|
||||
// TODO: The various operands of parallel operation are not handled.
|
||||
// Parallel operation is created with some default options for now.
|
||||
llvm::Value *ifCond = nullptr;
|
||||
if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
|
||||
ifCond = valueMapping.lookup(ifExprVar);
|
||||
llvm::Value *numThreads = nullptr;
|
||||
if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
|
||||
numThreads = valueMapping.lookup(numThreadsVar);
|
||||
llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default;
|
||||
if (auto bind = cast<omp::ParallelOp>(opInst).proc_bind_val())
|
||||
pbKind = llvm::omp::getProcBindKind(bind.getValue());
|
||||
// TODO: Is the Parallel construct cancellable?
|
||||
bool isCancellable = false;
|
||||
// TODO: Determine the actual alloca insertion point, e.g., the function
|
||||
// entry or the alloca insertion point as provided by the body callback
|
||||
// above.
|
||||
llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
|
||||
builder.restoreIP(ompBuilder->CreateParallel(
|
||||
builder, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads,
|
||||
llvm::omp::OMP_PROC_BIND_default, isCancellable));
|
||||
builder.restoreIP(
|
||||
ompBuilder->CreateParallel(builder, allocaIP, bodyGenCB, privCB, finiCB,
|
||||
ifCond, numThreads, pbKind, isCancellable));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -175,3 +175,34 @@ llvm.func @test_omp_parallel_if_1(%arg0: !llvm.i32) -> () {
|
|||
|
||||
// CHECK: define internal void @[[OMP_OUTLINED_FN_IF_1]]
|
||||
// CHECK: call void @__kmpc_barrier
|
||||
|
||||
// CHECK-LABEL: define void @test_omp_parallel_3()
|
||||
llvm.func @test_omp_parallel_3() -> () {
|
||||
// CHECK: [[OMP_THREAD_3_1:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
|
||||
// CHECK: call void @__kmpc_push_proc_bind(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD_3_1]], i32 2)
|
||||
// CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_3_1:.*]] to {{.*}}
|
||||
omp.parallel proc_bind(master) {
|
||||
omp.barrier
|
||||
omp.terminator
|
||||
}
|
||||
// CHECK: [[OMP_THREAD_3_2:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
|
||||
// CHECK: call void @__kmpc_push_proc_bind(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD_3_2]], i32 3)
|
||||
// CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_3_2:.*]] to {{.*}}
|
||||
omp.parallel proc_bind(close) {
|
||||
omp.barrier
|
||||
omp.terminator
|
||||
}
|
||||
// CHECK: [[OMP_THREAD_3_3:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
|
||||
// CHECK: call void @__kmpc_push_proc_bind(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD_3_3]], i32 4)
|
||||
// CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_3_3:.*]] to {{.*}}
|
||||
omp.parallel proc_bind(spread) {
|
||||
omp.barrier
|
||||
omp.terminator
|
||||
}
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK: define internal void @[[OMP_OUTLINED_FN_3_3]]
|
||||
// CHECK: define internal void @[[OMP_OUTLINED_FN_3_2]]
|
||||
// CHECK: define internal void @[[OMP_OUTLINED_FN_3_1]]
|
||||
|
|
|
@ -14,6 +14,7 @@ add_tablegen(mlir-tblgen MLIR
|
|||
OpDocGen.cpp
|
||||
OpFormatGen.cpp
|
||||
OpInterfacesGen.cpp
|
||||
OpenMPCommonGen.cpp
|
||||
PassGen.cpp
|
||||
PassDocGen.cpp
|
||||
RewriterGen.cpp
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
//===========- OpenMPCommonGen.cpp - OpenMP common info generator -===========//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// OpenMPCommonGen generates utility information from the single OpenMP source
|
||||
// of truth in llvm/lib/Frontend/OpenMP.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/TableGen/DirectiveEmitter.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
using llvm::Clause;
|
||||
using llvm::ClauseVal;
|
||||
using llvm::raw_ostream;
|
||||
using llvm::RecordKeeper;
|
||||
using llvm::Twine;
|
||||
|
||||
static bool emitDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
const auto &clauses = recordKeeper.getAllDerivedDefinitions("Clause");
|
||||
|
||||
for (const auto &r : clauses) {
|
||||
Clause c{r};
|
||||
const auto &clauseVals = c.getClauseVals();
|
||||
if (clauseVals.size() <= 0)
|
||||
continue;
|
||||
|
||||
const auto enumName = c.getEnumName();
|
||||
assert(enumName.size() != 0 && "enumClauseValue field not set.");
|
||||
|
||||
std::vector<std::string> cvDefs;
|
||||
for (const auto &cv : clauseVals) {
|
||||
ClauseVal cval{cv};
|
||||
if (!cval.isUserVisible())
|
||||
continue;
|
||||
|
||||
const auto name = cval.getFormattedName();
|
||||
std::string cvDef{(enumName + llvm::Twine(name)).str()};
|
||||
os << "def " << cvDef << " : StrEnumAttrCase<\"" << name << "\">;\n";
|
||||
cvDefs.push_back(cvDef);
|
||||
}
|
||||
|
||||
os << "def " << enumName << ": StrEnumAttr<\n";
|
||||
os << " \"Clause" << enumName << "\",\n";
|
||||
os << " \"" << enumName << " Clause\",\n";
|
||||
os << " [";
|
||||
for (unsigned int i = 0; i < cvDefs.size(); i++) {
|
||||
os << cvDefs[i];
|
||||
if (i != cvDefs.size() - 1)
|
||||
os << ",";
|
||||
}
|
||||
os << "]> {\n";
|
||||
os << " let cppNamespace = \"::mlir::omp\";\n";
|
||||
os << "}\n";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Registers the generator to mlir-tblgen.
|
||||
static mlir::GenRegistration
|
||||
genDirectiveDecls("gen-directive-decl",
|
||||
"Generate declarations for directives (OpenMP etc.)",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitDecls(records, os);
|
||||
});
|
Loading…
Reference in New Issue