diff --git a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h index 893167e33105..c030d51305e4 100644 --- a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h @@ -31,6 +31,7 @@ namespace LLVM { namespace detail { class TypeToLLVMIRTranslatorImpl; +class TypeFromLLVMIRTranslatorImpl; } // namespace detail /// Utility class to translate MLIR LLVM dialect types to LLVM IR. Stores the @@ -55,6 +56,22 @@ private: std::unique_ptr impl; }; +/// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores +/// the translation state, in particular any identified structure types that are +/// reused across translations. +class TypeFromLLVMIRTranslator { +public: + TypeFromLLVMIRTranslator(MLIRContext &context); + ~TypeFromLLVMIRTranslator(); + + /// Translates the given LLVM IR type to the MLIR LLVM dialect. + Type translateType(llvm::Type *type); + +private: + /// Private implementation. + std::unique_ptr impl; +}; + } // namespace LLVM } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 4a1653a39b63..3c272e0d2312 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Target/LLVMIR/Import.h" +#include "mlir/Target/LLVMIR/TypeTranslation.h" #include "mlir/Translation.h" #include "llvm/ADT/TypeSwitch.h" @@ -45,167 +46,6 @@ static std::string diag(llvm::Value &v) { return os.str(); } -namespace mlir { -namespace LLVM { -namespace detail { -/// Support for translating LLVM IR types to MLIR LLVM dialect types. -class TypeFromLLVMIRTranslatorImpl { -public: - /// Constructs a class creating types in the given MLIR context. - TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {} - - /// Translates the given type. - Type translateType(llvm::Type *type) { - if (knownTranslations.count(type)) - return knownTranslations.lookup(type); - - Type translated = - llvm::TypeSwitch(type) - .Case( - [this](auto *type) { return this->translate(type); }) - .Default([this](llvm::Type *type) { - return translatePrimitiveType(type); - }); - knownTranslations.try_emplace(type, translated); - return translated; - } - -private: - /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature, - /// type. - Type translatePrimitiveType(llvm::Type *type) { - if (type->isVoidTy()) - return LLVM::LLVMVoidType::get(&context); - if (type->isHalfTy()) - return Float16Type::get(&context); - if (type->isBFloatTy()) - return BFloat16Type::get(&context); - if (type->isFloatTy()) - return Float32Type::get(&context); - if (type->isDoubleTy()) - return Float64Type::get(&context); - if (type->isFP128Ty()) - return Float128Type::get(&context); - if (type->isX86_FP80Ty()) - return Float80Type::get(&context); - if (type->isPPC_FP128Ty()) - return LLVM::LLVMPPCFP128Type::get(&context); - if (type->isX86_MMXTy()) - return LLVM::LLVMX86MMXType::get(&context); - if (type->isLabelTy()) - return LLVM::LLVMLabelType::get(&context); - if (type->isMetadataTy()) - return LLVM::LLVMMetadataType::get(&context); - llvm_unreachable("not a primitive type"); - } - - /// Translates the given array type. - Type translate(llvm::ArrayType *type) { - return LLVM::LLVMArrayType::get(translateType(type->getElementType()), - type->getNumElements()); - } - - /// Translates the given function type. - Type translate(llvm::FunctionType *type) { - SmallVector paramTypes; - translateTypes(type->params(), paramTypes); - return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()), - paramTypes, type->isVarArg()); - } - - /// Translates the given integer type. - Type translate(llvm::IntegerType *type) { - return IntegerType::get(&context, type->getBitWidth()); - } - - /// Translates the given pointer type. - Type translate(llvm::PointerType *type) { - return LLVM::LLVMPointerType::get(translateType(type->getElementType()), - type->getAddressSpace()); - } - - /// Translates the given structure type. - Type translate(llvm::StructType *type) { - SmallVector subtypes; - if (type->isLiteral()) { - translateTypes(type->subtypes(), subtypes); - return LLVM::LLVMStructType::getLiteral(&context, subtypes, - type->isPacked()); - } - - if (type->isOpaque()) - return LLVM::LLVMStructType::getOpaque(type->getName(), &context); - - LLVM::LLVMStructType translated = - LLVM::LLVMStructType::getIdentified(&context, type->getName()); - knownTranslations.try_emplace(type, translated); - translateTypes(type->subtypes(), subtypes); - LogicalResult bodySet = translated.setBody(subtypes, type->isPacked()); - assert(succeeded(bodySet) && - "could not set the body of an identified struct"); - (void)bodySet; - return translated; - } - - /// Translates the given fixed-vector type. - Type translate(llvm::FixedVectorType *type) { - return LLVM::getFixedVectorType(translateType(type->getElementType()), - type->getNumElements()); - } - - /// Translates the given scalable-vector type. - Type translate(llvm::ScalableVectorType *type) { - return LLVM::LLVMScalableVectorType::get( - translateType(type->getElementType()), type->getMinNumElements()); - } - - /// Translates a list of types. - void translateTypes(ArrayRef types, - SmallVectorImpl &result) { - result.reserve(result.size() + types.size()); - for (llvm::Type *type : types) - result.push_back(translateType(type)); - } - - /// Map of known translations. Serves as a cache and as recursion stopper for - /// translating recursive structs. - llvm::DenseMap knownTranslations; - - /// The context in which MLIR types are created. - MLIRContext &context; -}; -} // end namespace detail - -/// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores -/// the translation state, in particular any identified structure types that are -/// reused across translations. -class TypeFromLLVMIRTranslator { -public: - TypeFromLLVMIRTranslator(MLIRContext &context); - ~TypeFromLLVMIRTranslator(); - - /// Translates the given LLVM IR type to the MLIR LLVM dialect. - Type translateType(llvm::Type *type); - -private: - /// Private implementation. - std::unique_ptr impl; -}; - -} // end namespace LLVM -} // end namespace mlir - -LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context) - : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {} - -LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {} - -Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) { - return impl->translateType(type); -} - // Handles importing globals and functions from an LLVM module. namespace { class Importer { diff --git a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp index f40a6d943cbb..27b032f98bc7 100644 --- a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp @@ -170,6 +170,136 @@ private: /// type instead of creating a new type. llvm::DenseMap knownTranslations; }; + +/// Support for translating LLVM IR types to MLIR LLVM dialect types. +class TypeFromLLVMIRTranslatorImpl { +public: + /// Constructs a class creating types in the given MLIR context. + TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {} + + /// Translates the given type. + Type translateType(llvm::Type *type) { + if (knownTranslations.count(type)) + return knownTranslations.lookup(type); + + Type translated = + llvm::TypeSwitch(type) + .Case( + [this](auto *type) { return this->translate(type); }) + .Default([this](llvm::Type *type) { + return translatePrimitiveType(type); + }); + knownTranslations.try_emplace(type, translated); + return translated; + } + +private: + /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature, + /// type. + Type translatePrimitiveType(llvm::Type *type) { + if (type->isVoidTy()) + return LLVM::LLVMVoidType::get(&context); + if (type->isHalfTy()) + return Float16Type::get(&context); + if (type->isBFloatTy()) + return BFloat16Type::get(&context); + if (type->isFloatTy()) + return Float32Type::get(&context); + if (type->isDoubleTy()) + return Float64Type::get(&context); + if (type->isFP128Ty()) + return Float128Type::get(&context); + if (type->isX86_FP80Ty()) + return Float80Type::get(&context); + if (type->isPPC_FP128Ty()) + return LLVM::LLVMPPCFP128Type::get(&context); + if (type->isX86_MMXTy()) + return LLVM::LLVMX86MMXType::get(&context); + if (type->isLabelTy()) + return LLVM::LLVMLabelType::get(&context); + if (type->isMetadataTy()) + return LLVM::LLVMMetadataType::get(&context); + llvm_unreachable("not a primitive type"); + } + + /// Translates the given array type. + Type translate(llvm::ArrayType *type) { + return LLVM::LLVMArrayType::get(translateType(type->getElementType()), + type->getNumElements()); + } + + /// Translates the given function type. + Type translate(llvm::FunctionType *type) { + SmallVector paramTypes; + translateTypes(type->params(), paramTypes); + return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()), + paramTypes, type->isVarArg()); + } + + /// Translates the given integer type. + Type translate(llvm::IntegerType *type) { + return IntegerType::get(&context, type->getBitWidth()); + } + + /// Translates the given pointer type. + Type translate(llvm::PointerType *type) { + return LLVM::LLVMPointerType::get(translateType(type->getElementType()), + type->getAddressSpace()); + } + + /// Translates the given structure type. + Type translate(llvm::StructType *type) { + SmallVector subtypes; + if (type->isLiteral()) { + translateTypes(type->subtypes(), subtypes); + return LLVM::LLVMStructType::getLiteral(&context, subtypes, + type->isPacked()); + } + + if (type->isOpaque()) + return LLVM::LLVMStructType::getOpaque(type->getName(), &context); + + LLVM::LLVMStructType translated = + LLVM::LLVMStructType::getIdentified(&context, type->getName()); + knownTranslations.try_emplace(type, translated); + translateTypes(type->subtypes(), subtypes); + LogicalResult bodySet = translated.setBody(subtypes, type->isPacked()); + assert(succeeded(bodySet) && + "could not set the body of an identified struct"); + (void)bodySet; + return translated; + } + + /// Translates the given fixed-vector type. + Type translate(llvm::FixedVectorType *type) { + return LLVM::getFixedVectorType(translateType(type->getElementType()), + type->getNumElements()); + } + + /// Translates the given scalable-vector type. + Type translate(llvm::ScalableVectorType *type) { + return LLVM::LLVMScalableVectorType::get( + translateType(type->getElementType()), type->getMinNumElements()); + } + + /// Translates a list of types. + void translateTypes(ArrayRef types, + SmallVectorImpl &result) { + result.reserve(result.size() + types.size()); + for (llvm::Type *type : types) + result.push_back(translateType(type)); + } + + /// Map of known translations. Serves as a cache and as recursion stopper for + /// translating recursive structs. + llvm::DenseMap knownTranslations; + + /// The context in which MLIR types are created. + MLIRContext &context; +}; + } // end namespace detail } // end namespace LLVM } // end namespace mlir @@ -187,3 +317,12 @@ unsigned LLVM::TypeToLLVMIRTranslator::getPreferredAlignment( Type type, const llvm::DataLayout &layout) { return layout.getPrefTypeAlignment(translateType(type)); } + +LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context) + : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {} + +LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {} + +Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) { + return impl->translateType(type); +}