[mlir][LLVM] Add support for Calling Convention in LLVMFuncOp

This patch adds support for Calling Convention attribute in LLVM
dialect, including enums, custom syntax and import from LLVM IR.
Additionally fix import of dso_local attribute.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D126161
This commit is contained in:
Alexander Batashev 2022-05-27 09:23:27 +03:00
parent a84026821b
commit 0252357b3e
11 changed files with 333 additions and 12 deletions

View File

@ -35,6 +35,15 @@ def LinkageAttr : LLVM_Attr<"Linkage"> {
let hasCustomAssemblyFormat = 1;
}
// Attribute definition for the LLVM Linkage enum.
def CConvAttr : LLVM_Attr<"CConv"> {
let mnemonic = "cconv";
let parameters = (ins
"CConv":$CConv
);
let hasCustomAssemblyFormat = 1;
}
def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
let mnemonic = "loopopts";

View File

@ -39,6 +39,7 @@ namespace LLVM {
// attribute definition itself.
// TODO: this shouldn't be needed after we unify the attribute generation, i.e.
// --gen-attr-* and --gen-attrdef-*.
using cconv::CConv;
using linkage::Linkage;
} // namespace LLVM
} // namespace mlir

View File

@ -234,6 +234,14 @@ class LLVM_EnumAttr<string name, string llvmName, string description,
string llvmClassName = llvmName;
}
// LLVM_CEnumAttr is functionally identical to LLVM_EnumAttr, but to be used for
// non-class enums.
class LLVM_CEnumAttr<string name, string llvmNS, string description,
list<LLVM_EnumAttrCase> cases> :
I64EnumAttr<name, description, cases> {
string llvmClassName = llvmNS;
}
// For every value in the list, substitutes the value in the place of "$0" in
// "pattern" and stores the list of strings as "lst".
class ListIntSubst<string pattern, list<int> values> {

View File

@ -67,6 +67,134 @@ def LoopOptionCase : I32EnumAttr<
let cppNamespace = "::mlir::LLVM";
}
// These values must match llvm::CallingConv ones.
// See https://llvm.org/doxygen/namespacellvm_1_1CallingConv.html for full list
// of supported calling conventions.
def CConvC : LLVM_EnumAttrCase<"C", "ccc", "C", 0>;
def CConvFast : LLVM_EnumAttrCase<"Fast", "fastcc", "Fast", 8>;
def CConvCold : LLVM_EnumAttrCase<"Cold", "coldcc", "Cold", 9>;
def CConvGHC : LLVM_EnumAttrCase<"GHC", "cc_10", "GHC", 10>;
def CConvHiPE : LLVM_EnumAttrCase<"HiPE", "cc_11", "HiPE", 11>;
def CConvWebKitJS : LLVM_EnumAttrCase<"WebKit_JS", "webkit_jscc",
"WebKit_JS", 12>;
def CConvAnyReg : LLVM_EnumAttrCase<"AnyReg", "anyregcc", "AnyReg", 13>;
def CConvPreserveMost : LLVM_EnumAttrCase<"PreserveMost", "preserve_mostcc",
"PreserveMost", 14>;
def CConvPreserveAll : LLVM_EnumAttrCase<"PreserveAll", "preserve_allcc",
"PreserveAll", 15>;
def CConvSwift : LLVM_EnumAttrCase<"Swift", "swiftcc", "Swift", 16>;
def CConvCXXFastTLS : LLVM_EnumAttrCase<"CXX_FAST_TLS", "cxx_fast_tlscc",
"CXX_FAST_TLS", 17>;
def CConvTail : LLVM_EnumAttrCase<"Tail", "tailcc", "Tail", 18>;
def CConvCFGuard_Check : LLVM_EnumAttrCase<"CFGuard_Check",
"cfguard_checkcc",
"CFGuard_Check", 19>;
def CConvSwiftTail : LLVM_EnumAttrCase<"SwiftTail", "swifttailcc",
"SwiftTail", 20>;
def CConvX86_StdCall : LLVM_EnumAttrCase<"X86_StdCall", "x86_stdcallcc",
"X86_StdCall", 64>;
def CConvX86_FastCall : LLVM_EnumAttrCase<"X86_FastCall", "x86_fastcallcc",
"X86_FastCall", 65>;
def CConvARM_APCS : LLVM_EnumAttrCase<"ARM_APCS", "arm_apcscc", "ARM_APCS", 66>;
def CConvARM_AAPCS : LLVM_EnumAttrCase<"ARM_AAPCS", "arm_aapcscc", "ARM_AAPCS",
67>;
def CConvARM_AAPCS_VFP : LLVM_EnumAttrCase<"ARM_AAPCS_VFP", "arm_aapcs_vfpcc",
"ARM_AAPCS_VFP", 68>;
def CConvMSP430_INTR : LLVM_EnumAttrCase<"MSP430_INTR", "msp430_intrcc",
"MSP430_INTR", 69>;
def CConvX86_ThisCall : LLVM_EnumAttrCase<"X86_ThisCall", "x86_thiscallcc",
"X86_ThisCall", 70>;
def CConvPTX_Kernel : LLVM_EnumAttrCase<"PTX_Kernel", "ptx_kernelcc",
"PTX_Kernel", 71>;
def CConvPTX_Device : LLVM_EnumAttrCase<"PTX_Device", "ptx_devicecc",
"PTX_Device", 72>;
def CConvSPIR_FUNC : LLVM_EnumAttrCase<"SPIR_FUNC", "spir_funccc",
"SPIR_FUNC", 75>;
def CConvSPIR_KERNEL : LLVM_EnumAttrCase<"SPIR_KERNEL", "spir_kernelcc",
"SPIR_KERNEL", 76>;
def CConvIntel_OCL_BI : LLVM_EnumAttrCase<"Intel_OCL_BI", "intel_ocl_bicc",
"Intel_OCL_BI", 77>;
def CConvX86_64_SysV : LLVM_EnumAttrCase<"X86_64_SysV", "x86_64_sysvcc",
"X86_64_SysV", 78>;
def CConvWin64 : LLVM_EnumAttrCase<"Win64", "win64cc", "Win64", 79>;
def CConvX86_VectorCall : LLVM_EnumAttrCase<"X86_VectorCall",
"x86_vectorcallcc",
"X86_VectorCall", 80>;
def CConvHHVM : LLVM_EnumAttrCase<"HHVM", "hhvmcc", "HHVM", 81>;
def CConvHHVM_C : LLVM_EnumAttrCase<"HHVM_C", "hhvm_ccc", "HHVM_C", 82>;
def CConvX86_INTR : LLVM_EnumAttrCase<"X86_INTR", "x86_intrcc", "X86_INTR", 83>;
def CConvAVR_INTR : LLVM_EnumAttrCase<"AVR_INTR", "avr_intrcc", "AVR_INTR", 84>;
def CConvAVR_SIGNAL : LLVM_EnumAttrCase<"AVR_SIGNAL", "avr_signalcc",
"AVR_SIGNAL", 85>;
def CConvAVR_BUILTIN : LLVM_EnumAttrCase<"AVR_BUILTIN", "avr_builtincc",
"AVR_BUILTIN", 86>;
def CConvAMDGPU_VS : LLVM_EnumAttrCase<"AMDGPU_VS", "amdgpu_vscc", "AMDGPU_VS",
87>;
def CConvAMDGPU_GS : LLVM_EnumAttrCase<"AMDGPU_GS", "amdgpu_gscc", "AMDGPU_GS",
88>;
def CConvAMDGPU_PS : LLVM_EnumAttrCase<"AMDGPU_PS", "amdgpu_pscc", "AMDGPU_PS",
89>;
def CConvAMDGPU_CS : LLVM_EnumAttrCase<"AMDGPU_CS", "amdgpu_cscc", "AMDGPU_CS",
90>;
def CConvAMDGPU_KERNEL : LLVM_EnumAttrCase<"AMDGPU_KERNEL", "amdgpu_kernelcc",
"AMDGPU_KERNEL", 91>;
def CConvX86_RegCall : LLVM_EnumAttrCase<"X86_RegCall", "x86_regcallcc",
"X86_RegCall", 92>;
def CConvAMDGPU_HS : LLVM_EnumAttrCase<"AMDGPU_HS", "amdgpu_hscc", "AMDGPU_HS",
93>;
def CConvMSP430_BUILTIN : LLVM_EnumAttrCase<"MSP430_BUILTIN",
"msp430_builtincc",
"MSP430_BUILTIN", 94>;
def CConvAMDGPU_LS : LLVM_EnumAttrCase<"AMDGPU_LS", "amdgpu_lscc", "AMDGPU_LS",
95>;
def CConvAMDGPU_ES : LLVM_EnumAttrCase<"AMDGPU_ES", "amdgpu_escc", "AMDGPU_ES",
96>;
def CConvAArch64_VectorCall : LLVM_EnumAttrCase<"AArch64_VectorCall",
"aarch64_vectorcallcc",
"AArch64_VectorCall", 97>;
def CConvAArch64_SVE_VectorCall : LLVM_EnumAttrCase<"AArch64_SVE_VectorCall",
"aarch64_sve_vectorcallcc",
"AArch64_SVE_VectorCall",
98>;
def CConvWASM_EmscriptenInvoke : LLVM_EnumAttrCase<"WASM_EmscriptenInvoke",
"wasm_emscripten_invokecc",
"WASM_EmscriptenInvoke", 99>;
def CConvAMDGPU_Gfx : LLVM_EnumAttrCase<"AMDGPU_Gfx", "amdgpu_gfxcc",
"AMDGPU_Gfx", 100>;
def CConvM68k_INTR : LLVM_EnumAttrCase<"M68k_INTR", "m68k_intrcc", "M68k_INTR",
101>;
def CConvEnum : LLVM_CEnumAttr<
"CConv",
"::llvm::CallingConv",
"Calling Conventions",
[CConvC, CConvFast, CConvCold, CConvGHC, CConvHiPE, CConvWebKitJS,
CConvAnyReg, CConvPreserveMost, CConvPreserveAll, CConvSwift,
CConvCXXFastTLS, CConvTail, CConvCFGuard_Check, CConvSwiftTail,
CConvX86_StdCall, CConvX86_FastCall, CConvARM_APCS,
CConvARM_AAPCS, CConvARM_AAPCS_VFP, CConvMSP430_INTR, CConvX86_ThisCall,
CConvPTX_Kernel, CConvPTX_Device, CConvSPIR_FUNC, CConvSPIR_KERNEL,
CConvIntel_OCL_BI, CConvX86_64_SysV, CConvWin64, CConvX86_VectorCall,
CConvHHVM, CConvHHVM_C, CConvX86_INTR, CConvAVR_INTR, CConvAVR_BUILTIN,
CConvAMDGPU_VS, CConvAMDGPU_GS, CConvAMDGPU_CS, CConvAMDGPU_KERNEL,
CConvX86_RegCall, CConvAMDGPU_HS, CConvMSP430_BUILTIN, CConvAMDGPU_LS,
CConvAMDGPU_ES, CConvAArch64_VectorCall, CConvAArch64_SVE_VectorCall,
CConvWASM_EmscriptenInvoke, CConvAMDGPU_Gfx, CConvM68k_INTR
]> {
let cppNamespace = "::mlir::LLVM::cconv";
}
def CConv : DialectAttr<
LLVM_Dialect,
CPred<"$_self.isa<::mlir::LLVM::CConvAttr>()">,
"LLVM Calling Convention specification"> {
let storageType = "::mlir::LLVM::CConvAttr";
let returnType = "::mlir::LLVM::cconv::CConv";
let convertFromStorage = "$_self.getCConv()";
let constBuilderCall =
"::mlir::LLVM::CConvAttr::get($_builder.getContext(), $0)";
}
class LLVM_Builder<string builder> {
string llvmBuilder = builder;
}
@ -1233,6 +1361,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
TypeAttrOf<LLVM_FunctionType>:$function_type,
DefaultValuedAttr<Linkage, "Linkage::External">:$linkage,
UnitAttr:$dso_local,
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
OptionalAttr<FlatSymbolRefAttr>:$personality,
OptionalAttr<StrAttr>:$garbageCollector,
OptionalAttr<ArrayAttr>:$passthrough
@ -1246,6 +1375,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
OpBuilder<(ins "StringRef":$name, "Type":$type,
CArg<"Linkage", "Linkage::External">:$linkage,
CArg<"bool", "false">:$dsoLocal,
CArg<"CConv", "CConv::C">:$cconv,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)>
];

View File

@ -139,7 +139,8 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments());
auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false,
/*cconv*/ LLVM::CConv::C, attributes);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
@ -206,7 +207,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
// Create the auxiliary function.
auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false,
/*cconv*/ LLVM::CConv::C, attributes);
builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
@ -345,7 +347,7 @@ protected:
}
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal*/ false, attributes);
/*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C, attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,

View File

@ -68,7 +68,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
LLVM::Linkage::External, /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C,
attributes);
{
// Insert operations that correspond to converted workgroup and private

View File

@ -37,6 +37,7 @@
using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::cconv::getMaxEnumValForCConv;
using mlir::LLVM::linkage::getMaxEnumValForLinkage;
#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
@ -1821,6 +1822,7 @@ struct EnumTraits {};
REGISTER_ENUM_TYPE(Linkage);
REGISTER_ENUM_TYPE(UnnamedAddr);
REGISTER_ENUM_TYPE(CConv);
} // namespace
/// Parse an enum from the keyword, or default to the provided default value.
@ -2124,7 +2126,8 @@ Block *LLVMFuncOp::addEntryBlock() {
void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
StringRef name, Type type, LLVM::Linkage linkage,
bool dsoLocal, ArrayRef<NamedAttribute> attrs,
bool dsoLocal, CConv cconv,
ArrayRef<NamedAttribute> attrs,
ArrayRef<DictionaryAttr> argAttrs) {
result.addRegion();
result.addAttribute(SymbolTable::getSymbolAttrName(),
@ -2133,6 +2136,8 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
TypeAttr::get(type));
result.addAttribute(getLinkageAttrName(result.name),
LinkageAttr::get(builder.getContext(), linkage));
result.addAttribute(getCConvAttrName(result.name),
CConvAttr::get(builder.getContext(), cconv));
result.attributes.append(attrs.begin(), attrs.end());
if (dsoLocal)
result.addAttribute("dso_local", builder.getUnitAttr());
@ -2185,7 +2190,8 @@ buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
// Parses an LLVM function.
//
// operation ::= `llvm.func` linkage? function-signature function-attributes?
// operation ::= `llvm.func` linkage? cconv? function-signature
// function-attributes?
// function-body
//
ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
@ -2196,6 +2202,12 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
parseOptionalLLVMKeyword<Linkage>(
parser, result, LLVM::Linkage::External)));
// Default to C Calling Convention if no keyword is provided.
result.addAttribute(
getCConvAttrName(result.name),
CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
parser, result, LLVM::CConv::C)));
StringAttr nameAttr;
SmallVector<OpAsmParser::Argument> entryArgs;
SmallVector<DictionaryAttr> resultAttrs;
@ -2239,6 +2251,9 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
p << ' ';
if (getLinkage() != LLVM::Linkage::External)
p << stringifyLinkage(getLinkage()) << ' ';
if (getCConv() != LLVM::CConv::C)
p << stringifyCConv(getCConv()) << ' ';
p.printSymbolName(getName());
LLVMFunctionType fnType = getFunctionType();
@ -2255,7 +2270,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionSignature(p, *this, argTypes,
isVarArg(), resTypes);
function_interface_impl::printFunctionAttributes(
p, *this, argTypes.size(), resTypes.size(), {getLinkageAttrName()});
p, *this, argTypes.size(), resTypes.size(),
{getLinkageAttrName(), getCConvAttrName()});
// Print the body if this is not an external function.
Region &body = getBody();
@ -2645,7 +2661,7 @@ OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
void LLVMDialect::initialize() {
addAttributes<FMFAttr, LinkageAttr, LoopOptionsAttr>();
addAttributes<FMFAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>();
// clang-format off
addTypes<LLVMVoidType,
@ -2940,6 +2956,31 @@ Attribute LinkageAttr::parse(AsmParser &parser, Type type) {
return LinkageAttr::get(parser.getContext(), linkage);
}
void CConvAttr::print(AsmPrinter &printer) const {
printer << "<";
if (static_cast<uint64_t>(getCConv()) <= cconv::getMaxEnumValForCConv())
printer << stringifyEnum(getCConv());
else
printer << "INVALID_cc_" << static_cast<uint64_t>(getCConv());
printer << ">";
}
Attribute CConvAttr::parse(AsmParser &parser, Type type) {
StringRef convName;
if (parser.parseLess() || parser.parseKeyword(&convName) ||
parser.parseGreater())
return {};
auto cconv = cconv::symbolizeCConv(convName);
if (!cconv) {
parser.emitError(parser.getNameLoc(), "unknown calling convention: ")
<< convName;
return {};
}
CConv cconvVal = *cconv;
return CConvAttr::get(parser.getContext(), cconvVal);
}
LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
: options(attr.getOptions().begin(), attr.getOptions().end()) {}

View File

@ -1139,10 +1139,13 @@ LogicalResult Importer::processFunction(llvm::Function *f) {
if (!functionType)
return failure();
bool dsoLocal = f->hasLocalLinkage();
CConv cconv = convertCConvFromLLVM(f->getCallingConv());
b.setInsertionPoint(module.getBody(), getFuncInsertPt());
LLVMFuncOp fop =
b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(), functionType,
convertLinkageFromLLVM(f->getLinkage()));
LLVMFuncOp fop = b.create<LLVMFuncOp>(
UnknownLoc::get(context), f->getName(), functionType,
convertLinkageFromLLVM(f->getLinkage()), dsoLocal, cconv);
if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
fop->setAttr(b.getStringAttr("personality"), personality);

View File

@ -144,6 +144,21 @@ module {
-> (!llvm.struct<(i32)> {llvm.struct_attrs = [{llvm.noalias}]}) {
llvm.return %arg0 : !llvm.struct<(i32)>
}
// CHECK: llvm.func @cconv1
llvm.func ccc @cconv1() {
llvm.return
}
// CHECK: llvm.func weak @cconv2
llvm.func weak ccc @cconv2() {
llvm.return
}
// CHECK: llvm.func weak fastcc @cconv3
llvm.func weak fastcc @cconv3() {
llvm.return
}
}
// -----
@ -251,3 +266,18 @@ module {
// expected-error@+1 {{functions cannot have 'common' linkage}}
llvm.func common @common_linkage_func()
}
// -----
module {
// expected-error@+1 {{custom op 'llvm.func' expected valid '@'-identifier for symbol name}}
llvm.func cc_12 @unknown_calling_convention()
}
// -----
module {
// expected-error@+2 {{unknown calling convention: cc_12}}
"llvm.func"() ({
}) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
}

View File

@ -122,8 +122,13 @@ define internal void @func_internal() {
; CHECK: llvm.func @fe(i32) -> f32
declare float @fe(i32)
; CHECK: llvm.func internal spir_funccc @spir_func_internal()
define internal spir_func void @spir_func_internal() {
ret void
}
; FIXME: function attributes.
; CHECK-LABEL: llvm.func internal @f1(%arg0: i64) -> i32 {
; CHECK-LABEL: llvm.func internal @f1(%arg0: i64) -> i32 attributes {dso_local} {
; CHECK-DAG: %[[c2:[0-9]+]] = llvm.mlir.constant(2 : i32) : i32
; CHECK-DAG: %[[c42:[0-9]+]] = llvm.mlir.constant(42 : i32) : i32
; CHECK-DAG: %[[c1:[0-9]+]] = llvm.mlir.constant(true) : i1

View File

@ -210,6 +210,27 @@ public:
return cases;
}
};
// Wraper class around a Tablegen definition of a C-style LLVM enum attribute.
class LLVMCEnumAttr : public tblgen::EnumAttr {
public:
using tblgen::EnumAttr::EnumAttr;
// Returns the C++ enum name for the LLVM API.
StringRef getLLVMClassName() const {
return def->getValueAsString("llvmClassName");
}
// Returns all associated cases viewed as LLVM-specific enum cases.
std::vector<LLVMEnumAttrCase> getAllCases() const {
std::vector<LLVMEnumAttrCase> cases;
for (auto &c : tblgen::EnumAttr::getAllCases())
cases.emplace_back(c);
return cases;
}
};
} // namespace
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
@ -242,6 +263,37 @@ static void emitOneEnumToConversion(const llvm::Record *record,
os << "}\n\n";
}
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
// (Enum) to the corresponding LLVM API C-style enumerant
static void emitOneCEnumToConversion(const llvm::Record *record,
raw_ostream &os) {
LLVMCEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
// Emit the function converting the enum attribute to its LLVM counterpart.
os << formatv("static LLVM_ATTRIBUTE_UNUSED int64_t "
"convert{0}ToLLVM({1}::{0} value) {{\n",
cppClassName, cppNamespace);
os << " switch (value) {\n";
for (const auto &enumerant : enumAttr.getAllCases()) {
StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
StringRef cppEnumerant = enumerant.getSymbol();
os << formatv(" case {0}::{1}::{2}:\n", cppNamespace, cppClassName,
cppEnumerant);
os << formatv(" return static_cast<int64_t>({0}::{1});\n", llvmClass,
llvmEnumerant);
}
os << " }\n";
os << formatv(" llvm_unreachable(\"unknown {0} type\");\n",
enumAttr.getEnumClassName());
os << "}\n\n";
}
// Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and
// containing switch-based logic to convert from the LLVM API enumerant to MLIR
// LLVM dialect enum attribute (Enum).
@ -272,6 +324,38 @@ static void emitOneEnumFromConversion(const llvm::Record *record,
os << "}\n\n";
}
// Emits conversion function "Enum convertEnumFromLLVM(LLVMEnum)" and
// containing switch-based logic to convert from the LLVM API C-style enumerant
// to MLIR LLVM dialect enum attribute (Enum).
static void emitOneCEnumFromConversion(const llvm::Record *record,
raw_ostream &os) {
LLVMCEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
// Emit the function converting the enum attribute from its LLVM counterpart.
os << formatv(
"inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM(int64_t "
"value) {{\n",
cppNamespace, cppClassName, llvmClass);
os << " switch (value) {\n";
for (const auto &enumerant : enumAttr.getAllCases()) {
StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
StringRef cppEnumerant = enumerant.getSymbol();
os << formatv(" case static_cast<int64_t>({0}::{1}):\n", llvmClass,
llvmEnumerant);
os << formatv(" return {0}::{1}::{2};\n", cppNamespace, cppClassName,
cppEnumerant);
}
os << " }\n";
os << formatv(" llvm_unreachable(\"unknown {0} type\");",
enumAttr.getLLVMClassName());
os << "}\n\n";
}
// Emits conversion functions between MLIR enum attribute case and corresponding
// LLVM API enumerants for all registered LLVM dialect enum attributes.
template <bool ConvertTo>
@ -283,6 +367,13 @@ static bool emitEnumConversionDefs(const RecordKeeper &recordKeeper,
else
emitOneEnumFromConversion(def, os);
for (const auto *def :
recordKeeper.getAllDerivedDefinitions("LLVM_CEnumAttr"))
if (ConvertTo)
emitOneCEnumToConversion(def, os);
else
emitOneCEnumFromConversion(def, os);
return false;
}