forked from OSchip/llvm-project
[MLIR][LLVM] Expose type translator from LLVM to MLIR Type
This commit moves the type translator from LLVM to MLIR to a public header for use by external projects or other code Differential Revision: https://reviews.llvm.org/D104726
This commit is contained in:
parent
45d5373511
commit
5616a79398
|
@ -31,6 +31,7 @@ namespace LLVM {
|
|||
|
||||
namespace detail {
|
||||
class TypeToLLVMIRTranslatorImpl;
|
||||
class TypeFromLLVMIRTranslatorImpl;
|
||||
} // namespace detail
|
||||
|
||||
/// Utility class to translate MLIR LLVM dialect types to LLVM IR. Stores the
|
||||
|
@ -55,6 +56,22 @@ private:
|
|||
std::unique_ptr<detail::TypeToLLVMIRTranslatorImpl> impl;
|
||||
};
|
||||
|
||||
/// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores
|
||||
/// the translation state, in particular any identified structure types that are
|
||||
/// reused across translations.
|
||||
class TypeFromLLVMIRTranslator {
|
||||
public:
|
||||
TypeFromLLVMIRTranslator(MLIRContext &context);
|
||||
~TypeFromLLVMIRTranslator();
|
||||
|
||||
/// Translates the given LLVM IR type to the MLIR LLVM dialect.
|
||||
Type translateType(llvm::Type *type);
|
||||
|
||||
private:
|
||||
/// Private implementation.
|
||||
std::unique_ptr<detail::TypeFromLLVMIRTranslatorImpl> impl;
|
||||
};
|
||||
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/Target/LLVMIR/Import.h"
|
||||
#include "mlir/Target/LLVMIR/TypeTranslation.h"
|
||||
#include "mlir/Translation.h"
|
||||
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
@ -45,167 +46,6 @@ static std::string diag(llvm::Value &v) {
|
|||
return os.str();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
namespace detail {
|
||||
/// Support for translating LLVM IR types to MLIR LLVM dialect types.
|
||||
class TypeFromLLVMIRTranslatorImpl {
|
||||
public:
|
||||
/// Constructs a class creating types in the given MLIR context.
|
||||
TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
|
||||
|
||||
/// Translates the given type.
|
||||
Type translateType(llvm::Type *type) {
|
||||
if (knownTranslations.count(type))
|
||||
return knownTranslations.lookup(type);
|
||||
|
||||
Type translated =
|
||||
llvm::TypeSwitch<llvm::Type *, Type>(type)
|
||||
.Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
|
||||
llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
|
||||
llvm::ScalableVectorType>(
|
||||
[this](auto *type) { return this->translate(type); })
|
||||
.Default([this](llvm::Type *type) {
|
||||
return translatePrimitiveType(type);
|
||||
});
|
||||
knownTranslations.try_emplace(type, translated);
|
||||
return translated;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
|
||||
/// type.
|
||||
Type translatePrimitiveType(llvm::Type *type) {
|
||||
if (type->isVoidTy())
|
||||
return LLVM::LLVMVoidType::get(&context);
|
||||
if (type->isHalfTy())
|
||||
return Float16Type::get(&context);
|
||||
if (type->isBFloatTy())
|
||||
return BFloat16Type::get(&context);
|
||||
if (type->isFloatTy())
|
||||
return Float32Type::get(&context);
|
||||
if (type->isDoubleTy())
|
||||
return Float64Type::get(&context);
|
||||
if (type->isFP128Ty())
|
||||
return Float128Type::get(&context);
|
||||
if (type->isX86_FP80Ty())
|
||||
return Float80Type::get(&context);
|
||||
if (type->isPPC_FP128Ty())
|
||||
return LLVM::LLVMPPCFP128Type::get(&context);
|
||||
if (type->isX86_MMXTy())
|
||||
return LLVM::LLVMX86MMXType::get(&context);
|
||||
if (type->isLabelTy())
|
||||
return LLVM::LLVMLabelType::get(&context);
|
||||
if (type->isMetadataTy())
|
||||
return LLVM::LLVMMetadataType::get(&context);
|
||||
llvm_unreachable("not a primitive type");
|
||||
}
|
||||
|
||||
/// Translates the given array type.
|
||||
Type translate(llvm::ArrayType *type) {
|
||||
return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
|
||||
type->getNumElements());
|
||||
}
|
||||
|
||||
/// Translates the given function type.
|
||||
Type translate(llvm::FunctionType *type) {
|
||||
SmallVector<Type, 8> paramTypes;
|
||||
translateTypes(type->params(), paramTypes);
|
||||
return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
|
||||
paramTypes, type->isVarArg());
|
||||
}
|
||||
|
||||
/// Translates the given integer type.
|
||||
Type translate(llvm::IntegerType *type) {
|
||||
return IntegerType::get(&context, type->getBitWidth());
|
||||
}
|
||||
|
||||
/// Translates the given pointer type.
|
||||
Type translate(llvm::PointerType *type) {
|
||||
return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
|
||||
type->getAddressSpace());
|
||||
}
|
||||
|
||||
/// Translates the given structure type.
|
||||
Type translate(llvm::StructType *type) {
|
||||
SmallVector<Type, 8> subtypes;
|
||||
if (type->isLiteral()) {
|
||||
translateTypes(type->subtypes(), subtypes);
|
||||
return LLVM::LLVMStructType::getLiteral(&context, subtypes,
|
||||
type->isPacked());
|
||||
}
|
||||
|
||||
if (type->isOpaque())
|
||||
return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
|
||||
|
||||
LLVM::LLVMStructType translated =
|
||||
LLVM::LLVMStructType::getIdentified(&context, type->getName());
|
||||
knownTranslations.try_emplace(type, translated);
|
||||
translateTypes(type->subtypes(), subtypes);
|
||||
LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
|
||||
assert(succeeded(bodySet) &&
|
||||
"could not set the body of an identified struct");
|
||||
(void)bodySet;
|
||||
return translated;
|
||||
}
|
||||
|
||||
/// Translates the given fixed-vector type.
|
||||
Type translate(llvm::FixedVectorType *type) {
|
||||
return LLVM::getFixedVectorType(translateType(type->getElementType()),
|
||||
type->getNumElements());
|
||||
}
|
||||
|
||||
/// Translates the given scalable-vector type.
|
||||
Type translate(llvm::ScalableVectorType *type) {
|
||||
return LLVM::LLVMScalableVectorType::get(
|
||||
translateType(type->getElementType()), type->getMinNumElements());
|
||||
}
|
||||
|
||||
/// Translates a list of types.
|
||||
void translateTypes(ArrayRef<llvm::Type *> types,
|
||||
SmallVectorImpl<Type> &result) {
|
||||
result.reserve(result.size() + types.size());
|
||||
for (llvm::Type *type : types)
|
||||
result.push_back(translateType(type));
|
||||
}
|
||||
|
||||
/// Map of known translations. Serves as a cache and as recursion stopper for
|
||||
/// translating recursive structs.
|
||||
llvm::DenseMap<llvm::Type *, Type> knownTranslations;
|
||||
|
||||
/// The context in which MLIR types are created.
|
||||
MLIRContext &context;
|
||||
};
|
||||
} // end namespace detail
|
||||
|
||||
/// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores
|
||||
/// the translation state, in particular any identified structure types that are
|
||||
/// reused across translations.
|
||||
class TypeFromLLVMIRTranslator {
|
||||
public:
|
||||
TypeFromLLVMIRTranslator(MLIRContext &context);
|
||||
~TypeFromLLVMIRTranslator();
|
||||
|
||||
/// Translates the given LLVM IR type to the MLIR LLVM dialect.
|
||||
Type translateType(llvm::Type *type);
|
||||
|
||||
private:
|
||||
/// Private implementation.
|
||||
std::unique_ptr<detail::TypeFromLLVMIRTranslatorImpl> impl;
|
||||
};
|
||||
|
||||
} // end namespace LLVM
|
||||
} // end namespace mlir
|
||||
|
||||
LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
|
||||
: impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
|
||||
|
||||
LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
|
||||
|
||||
Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
|
||||
return impl->translateType(type);
|
||||
}
|
||||
|
||||
// Handles importing globals and functions from an LLVM module.
|
||||
namespace {
|
||||
class Importer {
|
||||
|
|
|
@ -170,6 +170,136 @@ private:
|
|||
/// type instead of creating a new type.
|
||||
llvm::DenseMap<Type, llvm::Type *> knownTranslations;
|
||||
};
|
||||
|
||||
/// Support for translating LLVM IR types to MLIR LLVM dialect types.
|
||||
class TypeFromLLVMIRTranslatorImpl {
|
||||
public:
|
||||
/// Constructs a class creating types in the given MLIR context.
|
||||
TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
|
||||
|
||||
/// Translates the given type.
|
||||
Type translateType(llvm::Type *type) {
|
||||
if (knownTranslations.count(type))
|
||||
return knownTranslations.lookup(type);
|
||||
|
||||
Type translated =
|
||||
llvm::TypeSwitch<llvm::Type *, Type>(type)
|
||||
.Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
|
||||
llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
|
||||
llvm::ScalableVectorType>(
|
||||
[this](auto *type) { return this->translate(type); })
|
||||
.Default([this](llvm::Type *type) {
|
||||
return translatePrimitiveType(type);
|
||||
});
|
||||
knownTranslations.try_emplace(type, translated);
|
||||
return translated;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
|
||||
/// type.
|
||||
Type translatePrimitiveType(llvm::Type *type) {
|
||||
if (type->isVoidTy())
|
||||
return LLVM::LLVMVoidType::get(&context);
|
||||
if (type->isHalfTy())
|
||||
return Float16Type::get(&context);
|
||||
if (type->isBFloatTy())
|
||||
return BFloat16Type::get(&context);
|
||||
if (type->isFloatTy())
|
||||
return Float32Type::get(&context);
|
||||
if (type->isDoubleTy())
|
||||
return Float64Type::get(&context);
|
||||
if (type->isFP128Ty())
|
||||
return Float128Type::get(&context);
|
||||
if (type->isX86_FP80Ty())
|
||||
return Float80Type::get(&context);
|
||||
if (type->isPPC_FP128Ty())
|
||||
return LLVM::LLVMPPCFP128Type::get(&context);
|
||||
if (type->isX86_MMXTy())
|
||||
return LLVM::LLVMX86MMXType::get(&context);
|
||||
if (type->isLabelTy())
|
||||
return LLVM::LLVMLabelType::get(&context);
|
||||
if (type->isMetadataTy())
|
||||
return LLVM::LLVMMetadataType::get(&context);
|
||||
llvm_unreachable("not a primitive type");
|
||||
}
|
||||
|
||||
/// Translates the given array type.
|
||||
Type translate(llvm::ArrayType *type) {
|
||||
return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
|
||||
type->getNumElements());
|
||||
}
|
||||
|
||||
/// Translates the given function type.
|
||||
Type translate(llvm::FunctionType *type) {
|
||||
SmallVector<Type, 8> paramTypes;
|
||||
translateTypes(type->params(), paramTypes);
|
||||
return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
|
||||
paramTypes, type->isVarArg());
|
||||
}
|
||||
|
||||
/// Translates the given integer type.
|
||||
Type translate(llvm::IntegerType *type) {
|
||||
return IntegerType::get(&context, type->getBitWidth());
|
||||
}
|
||||
|
||||
/// Translates the given pointer type.
|
||||
Type translate(llvm::PointerType *type) {
|
||||
return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
|
||||
type->getAddressSpace());
|
||||
}
|
||||
|
||||
/// Translates the given structure type.
|
||||
Type translate(llvm::StructType *type) {
|
||||
SmallVector<Type, 8> subtypes;
|
||||
if (type->isLiteral()) {
|
||||
translateTypes(type->subtypes(), subtypes);
|
||||
return LLVM::LLVMStructType::getLiteral(&context, subtypes,
|
||||
type->isPacked());
|
||||
}
|
||||
|
||||
if (type->isOpaque())
|
||||
return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
|
||||
|
||||
LLVM::LLVMStructType translated =
|
||||
LLVM::LLVMStructType::getIdentified(&context, type->getName());
|
||||
knownTranslations.try_emplace(type, translated);
|
||||
translateTypes(type->subtypes(), subtypes);
|
||||
LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
|
||||
assert(succeeded(bodySet) &&
|
||||
"could not set the body of an identified struct");
|
||||
(void)bodySet;
|
||||
return translated;
|
||||
}
|
||||
|
||||
/// Translates the given fixed-vector type.
|
||||
Type translate(llvm::FixedVectorType *type) {
|
||||
return LLVM::getFixedVectorType(translateType(type->getElementType()),
|
||||
type->getNumElements());
|
||||
}
|
||||
|
||||
/// Translates the given scalable-vector type.
|
||||
Type translate(llvm::ScalableVectorType *type) {
|
||||
return LLVM::LLVMScalableVectorType::get(
|
||||
translateType(type->getElementType()), type->getMinNumElements());
|
||||
}
|
||||
|
||||
/// Translates a list of types.
|
||||
void translateTypes(ArrayRef<llvm::Type *> types,
|
||||
SmallVectorImpl<Type> &result) {
|
||||
result.reserve(result.size() + types.size());
|
||||
for (llvm::Type *type : types)
|
||||
result.push_back(translateType(type));
|
||||
}
|
||||
|
||||
/// Map of known translations. Serves as a cache and as recursion stopper for
|
||||
/// translating recursive structs.
|
||||
llvm::DenseMap<llvm::Type *, Type> knownTranslations;
|
||||
|
||||
/// The context in which MLIR types are created.
|
||||
MLIRContext &context;
|
||||
};
|
||||
|
||||
} // end namespace detail
|
||||
} // end namespace LLVM
|
||||
} // end namespace mlir
|
||||
|
@ -187,3 +317,12 @@ unsigned LLVM::TypeToLLVMIRTranslator::getPreferredAlignment(
|
|||
Type type, const llvm::DataLayout &layout) {
|
||||
return layout.getPrefTypeAlignment(translateType(type));
|
||||
}
|
||||
|
||||
LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
|
||||
: impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
|
||||
|
||||
LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
|
||||
|
||||
Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
|
||||
return impl->translateType(type);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue