Add thread-safe utilities to LLVMType to allow constructing llvm types in a multi-threaded environment. The LLVMContext is not thread-safe and directly constructing a raw llvm::Type can create situations where the LLVMContext is modified by multiple threads at the same time.

--

PiperOrigin-RevId: 249526233
This commit is contained in:
River Riddle 2019-05-22 14:56:07 -07:00 committed by Mehdi Amini
parent c33862b0ed
commit 5953d12b95
8 changed files with 343 additions and 223 deletions

View File

@ -62,22 +62,14 @@ Type linalg::convertLinalgType(Type t) {
// Simple conversions.
if (t.isa<IndexType>()) {
int width = dialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
auto *integerTy = llvm::IntegerType::get(dialect->getLLVMContext(), width);
return LLVM::LLVMType::get(context, integerTy);
}
if (auto intTy = t.dyn_cast<IntegerType>()) {
int width = intTy.getWidth();
auto *integerTy = llvm::IntegerType::get(dialect->getLLVMContext(), width);
return LLVM::LLVMType::get(context, integerTy);
}
if (t.isF32()) {
auto *floatTy = llvm::Type::getFloatTy(dialect->getLLVMContext());
return LLVM::LLVMType::get(context, floatTy);
}
if (t.isF64()) {
auto *doubleTy = llvm::Type::getDoubleTy(dialect->getLLVMContext());
return LLVM::LLVMType::get(context, doubleTy);
return LLVM::LLVMType::getIntNTy(dialect, width);
}
if (auto intTy = t.dyn_cast<IntegerType>())
return LLVM::LLVMType::getIntNTy(dialect, intTy.getWidth());
if (t.isF32())
return LLVM::LLVMType::getFloatTy(dialect);
if (t.isF64())
return LLVM::LLVMType::getDoubleTy(dialect);
// Range descriptor contains the range bounds and the step as 64-bit integers.
//
@ -87,9 +79,8 @@ Type linalg::convertLinalgType(Type t) {
// int64_t step;
// };
if (auto rangeTy = t.dyn_cast<linalg::RangeType>()) {
auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext());
auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty);
return LLVM::LLVMType::get(context, structTy);
auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect);
return LLVM::LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
}
// View descriptor contains the pointer to the data buffer, followed by a
@ -116,14 +107,12 @@ Type linalg::convertLinalgType(Type t) {
// int64_t strides[Rank];
// };
if (auto viewTy = t.dyn_cast<linalg::ViewType>()) {
auto *elemTy = linalg::convertLinalgType(viewTy.getElementType())
.cast<LLVM::LLVMType>()
.getUnderlyingType()
->getPointerTo();
auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext());
auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank());
auto *structTy = llvm::StructType::get(elemTy, int64Ty, arrayTy, arrayTy);
return LLVM::LLVMType::get(context, structTy);
auto elemTy = linalg::convertLinalgType(viewTy.getElementType())
.cast<LLVM::LLVMType>()
.getPointerTo();
auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect);
auto arrayTy = LLVM::LLVMType::getArrayTy(int64Ty, viewTy.getRank());
return LLVM::LLVMType::getStructTy(elemTy, int64Ty, arrayTy, arrayTy);
}
// All other types are kept as is.
@ -217,11 +206,9 @@ public:
if (type.hasStaticShape())
return memref;
auto elementTy = LLVM::LLVMType::get(
type.getContext(), linalg::convertLinalgType(type.getElementType())
.cast<LLVM::LLVMType>()
.getUnderlyingType()
->getPointerTo());
auto elementTy = linalg::convertLinalgType(type.getElementType())
.cast<LLVM::LLVMType>()
.getPointerTo();
return intrinsics::extractvalue(elementTy, memref, pos(0));
};
@ -307,11 +294,9 @@ public:
auto sliceOp = cast<linalg::SliceOp>(op);
auto newViewDescriptorType =
linalg::convertLinalgType(sliceOp.getViewType());
auto elementType = rewriter.getType<LLVM::LLVMType>(
linalg::convertLinalgType(sliceOp.getElementType())
.cast<LLVM::LLVMType>()
.getUnderlyingType()
->getPointerTo());
auto elementType = linalg::convertLinalgType(sliceOp.getElementType())
.cast<LLVM::LLVMType>()
.getPointerTo();
auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64));
auto pos = [&rewriter](ArrayRef<int> values) {

View File

@ -67,11 +67,9 @@ public:
auto loadOp = cast<Op>(op);
auto elementType =
loadOp.getViewType().template cast<linalg::ViewType>().getElementType();
auto *llvmPtrType = linalg::convertLinalgType(elementType)
.template cast<LLVM::LLVMType>()
.getUnderlyingType()
->getPointerTo();
elementType = rewriter.getType<LLVM::LLVMType>(llvmPtrType);
elementType = linalg::convertLinalgType(elementType)
.template cast<LLVM::LLVMType>()
.getPointerTo();
auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64));
auto pos = [&rewriter](ArrayRef<int> values) {

View File

@ -210,15 +210,11 @@ private:
// Create a function declaration for printf, signature is `i32 (i8*, ...)`
Builder builder(&module);
MLIRContext *context = module.getContext();
auto *llvmDialect =
auto *dialect =
module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
auto &llvmModule = llvmDialect->getLLVMModule();
llvm::IRBuilder<> llvmBuilder(llvmModule.getContext());
auto llvmI32Ty = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(32));
auto llvmI8PtrTy =
LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(8)->getPointerTo());
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy);
// It should be variadic, but we don't support it fully just yet.

View File

@ -41,10 +41,12 @@ class LLVMContext;
namespace mlir {
namespace LLVM {
class LLVMDialect;
namespace detail {
struct LLVMTypeStorage;
}
struct LLVMDialectImpl;
} // namespace detail
class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
detail::LLVMTypeStorage> {
@ -57,9 +59,72 @@ public:
static bool kindof(unsigned kind) { return kind == LLVM_TYPE; }
static LLVMType get(MLIRContext *context, llvm::Type *llvmType);
LLVMDialect &getDialect();
llvm::Type *getUnderlyingType() const;
/// Array type utilities.
LLVMType getArrayElementType();
/// Pointer type utilities.
LLVMType getPointerTo(unsigned addrSpace = 0);
LLVMType getPointerElementTy();
/// Struct type utilities.
LLVMType getStructElementType(unsigned i);
/// Utilities used to generate floating point types.
static LLVMType getDoubleTy(LLVMDialect *dialect);
static LLVMType getFloatTy(LLVMDialect *dialect);
static LLVMType getHalfTy(LLVMDialect *dialect);
/// Utilities used to generate integer types.
static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits);
static LLVMType getInt1Ty(LLVMDialect *dialect) {
return getIntNTy(dialect, /*numBits=*/1);
}
static LLVMType getInt8Ty(LLVMDialect *dialect) {
return getIntNTy(dialect, /*numBits=*/8);
}
static LLVMType getInt8PtrTy(LLVMDialect *dialect) {
return getInt8Ty(dialect).getPointerTo();
}
static LLVMType getInt16Ty(LLVMDialect *dialect) {
return getIntNTy(dialect, /*numBits=*/16);
}
static LLVMType getInt32Ty(LLVMDialect *dialect) {
return getIntNTy(dialect, /*numBits=*/32);
}
static LLVMType getInt64Ty(LLVMDialect *dialect) {
return getIntNTy(dialect, /*numBits=*/64);
}
/// Utilities used to generate other miscellaneous types.
static LLVMType getArrayTy(LLVMType elementType, uint64_t numElements);
static LLVMType getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
bool isVarArg);
static LLVMType getFunctionTy(LLVMType result, bool isVarArg) {
return getFunctionTy(result, llvm::None, isVarArg);
}
static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef<LLVMType> elements,
bool isPacked = false);
static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) {
return getStructTy(dialect, llvm::None, isPacked);
}
template <typename... Args>
static typename std::enable_if<llvm::are_base_of<LLVMType, Args...>::value,
LLVMType>::type
getStructTy(LLVMType elt1, Args... elts) {
SmallVector<LLVMType, 8> fields({elt1, elts...});
return getStructTy(&elt1.getDialect(), fields);
}
static LLVMType getVectorTy(LLVMType elementType, unsigned numElements);
static LLVMType getVoidTy(LLVMDialect *dialect);
private:
friend LLVMDialect;
/// Get an LLVM type with a pre-existing llvm type.
static LLVMType get(MLIRContext *context, llvm::Type *llvmType);
};
///// Ops /////
@ -69,10 +134,11 @@ public:
class LLVMDialect : public Dialect {
public:
explicit LLVMDialect(MLIRContext *context);
~LLVMDialect();
static StringRef getDialectNamespace() { return "llvm"; }
llvm::LLVMContext &getLLVMContext() { return llvmContext; }
llvm::Module &getLLVMModule() { return module; }
llvm::LLVMContext &getLLVMContext();
llvm::Module &getLLVMModule();
/// Parse a type registered to this dialect.
Type parseType(StringRef tyData, Location loc) const override;
@ -86,8 +152,9 @@ public:
NamedAttribute argAttr) override;
private:
llvm::LLVMContext llvmContext;
llvm::Module module;
friend LLVMType;
std::unique_ptr<detail::LLVMDialectImpl> impl;
};
} // end namespace LLVM

View File

@ -31,11 +31,12 @@ class IntegerType;
class LLVMContext;
class Module;
class Type;
}
} // namespace llvm
namespace mlir {
namespace LLVM {
class LLVMDialect;
class LLVMType;
}
/// Conversion from the Standard dialect to the LLVM IR dialect. Provides hooks
@ -55,6 +56,9 @@ public:
/// Returns the LLVM context.
llvm::LLVMContext &getLLVMContext();
/// Returns the LLVM dialect.
LLVM::LLVMDialect *getDialect() { return llvmDialect; }
protected:
/// Add a set of converters to the given pattern list. Store the module
/// associated with the dialect for further type conversion.
@ -119,13 +123,13 @@ private:
// Get the LLVM representation of the index type based on the bitwidth of the
// pointer as defined by the data layout of the module.
llvm::IntegerType *getIndexType();
LLVM::LLVMType getIndexType();
// Wrap the given LLVM IR type into an LLVM IR dialect type.
Type wrap(llvm::Type *llvmType);
// Extract an LLVM IR type from the LLVM IR dialect type.
llvm::Type *unwrap(Type type);
// Extract an LLVM IR dialect type.
LLVM::LLVMType unwrap(Type type);
};
/// Base class for operation conversions targeting the LLVM IR dialect. Provides

View File

@ -29,40 +29,12 @@
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
using namespace mlir::LLVM;
namespace mlir {
namespace LLVM {
namespace detail {
struct LLVMTypeStorage : public ::mlir::TypeStorage {
LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {}
// LLVM types are pointer-unique.
using KeyTy = llvm::Type *;
bool operator==(const KeyTy &key) const { return key == underlyingType; }
static LLVMTypeStorage *construct(TypeStorageAllocator &allocator,
llvm::Type *ty) {
return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty);
}
llvm::Type *underlyingType;
};
} // end namespace detail
} // end namespace LLVM
} // end namespace mlir
LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
return Base::get(context, FIRST_LLVM_TYPE, llvmType);
}
llvm::Type *LLVMType::getUnderlyingType() const {
return getImpl()->underlyingType;
}
static void printLLVMBinaryOp(OpAsmPrinter *p, Operation *op) {
// Fallback to the generic form if the op is not well-formed (may happen
// during incomplete rewrites, and used for debugging).
@ -161,14 +133,13 @@ static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) {
// The result type is either i1 or a vector type <? x i1> if the inputs are
// vectors.
auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
llvm::Type *llvmResultType = llvm::Type::getInt1Ty(dialect->getLLVMContext());
auto resultType = LLVMType::getInt1Ty(dialect);
auto argType = type.dyn_cast<LLVM::LLVMType>();
if (!argType)
return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type");
if (argType.getUnderlyingType()->isVectorTy())
llvmResultType = llvm::VectorType::get(
llvmResultType, argType.getUnderlyingType()->getVectorNumElements());
auto resultType = builder.getType<LLVM::LLVMType>(llvmResultType);
resultType = LLVMType::getVectorTy(
resultType, argType.getUnderlyingType()->getVectorNumElements());
result->attributes = attrs;
result->addTypes({resultType});
@ -180,9 +151,7 @@ static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) {
//===----------------------------------------------------------------------===//
static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) {
auto *llvmPtrTy = op.getType().cast<LLVM::LLVMType>().getUnderlyingType();
auto *llvmElemTy = llvm::cast<llvm::PointerType>(llvmPtrTy)->getElementType();
auto elemTy = LLVM::LLVMType::get(op.getContext(), llvmElemTy);
auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()},
op.getContext());
@ -291,13 +260,10 @@ static Type getLoadStoreElementType(OpAsmParser *parser, Type type,
if (!llvmTy)
return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
nullptr;
auto *llvmPtrTy = dyn_cast<llvm::PointerType>(llvmTy.getUnderlyingType());
if (!llvmPtrTy)
if (!llvmTy.getUnderlyingType()->isPointerTy())
return parser->emitError(trailingTypeLoc, "expected LLVM pointer type"),
nullptr;
auto elemTy = LLVM::LLVMType::get(parser->getBuilder().getContext(),
llvmPtrTy->getElementType());
return elemTy;
return llvmTy.getPointerElementTy();
}
// <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
@ -465,33 +431,28 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
Builder &builder = parser->getBuilder();
auto *llvmDialect =
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
llvm::Type *llvmResultType;
Type wrappedResultType;
LLVM::LLVMType llvmResultType;
if (funcType.getNumResults() == 0) {
llvmResultType = llvm::Type::getVoidTy(llvmDialect->getLLVMContext());
wrappedResultType = builder.getType<LLVM::LLVMType>(llvmResultType);
llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
} else {
wrappedResultType = funcType.getResult(0);
auto wrappedLLVMResultType = wrappedResultType.dyn_cast<LLVM::LLVMType>();
if (!wrappedLLVMResultType)
llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
if (!llvmResultType)
return parser->emitError(trailingTypeLoc,
"expected result to have LLVM type");
llvmResultType = wrappedLLVMResultType.getUnderlyingType();
}
SmallVector<llvm::Type *, 8> argTypes;
SmallVector<LLVM::LLVMType, 8> argTypes;
argTypes.reserve(funcType.getNumInputs());
for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>();
if (!argType)
return parser->emitError(trailingTypeLoc,
"expected LLVM types as inputs");
argTypes.push_back(argType.getUnderlyingType());
argTypes.push_back(argType);
}
auto *llvmFuncType = llvm::FunctionType::get(llvmResultType, argTypes,
/*isVarArg=*/false);
auto wrappedFuncType =
builder.getType<LLVM::LLVMType>(llvmFuncType->getPointerTo());
auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
/*isVarArg=*/false);
auto wrappedFuncType = llvmFuncType.getPointerTo();
auto funcArguments =
ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
@ -505,7 +466,7 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
parser->getNameLoc(), result->operands))
return failure();
result->addTypes(wrappedResultType);
result->addTypes(llvmResultType);
}
result->attributes = attrs;
@ -544,7 +505,6 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser,
// type by taking the element type, indexed by the position attribute for
// stuctures. Check the position index before accessing, it is supposed to be
// in bounds.
llvm::Type *llvmContainerType = wrappedContainerType.getUnderlyingType();
for (Attribute subAttr : positionArrayAttr) {
auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
if (!positionElementAttr)
@ -552,27 +512,27 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser,
"expected an array of integer literals"),
nullptr;
int position = positionElementAttr.getInt();
auto *llvmContainerType = wrappedContainerType.getUnderlyingType();
if (llvmContainerType->isArrayTy()) {
if (position < 0 || static_cast<unsigned>(position) >=
llvmContainerType->getArrayNumElements())
return parser->emitError(attributeLoc, "position out of bounds"),
nullptr;
llvmContainerType = llvmContainerType->getArrayElementType();
wrappedContainerType = wrappedContainerType.getArrayElementType();
} else if (llvmContainerType->isStructTy()) {
if (position < 0 || static_cast<unsigned>(position) >=
llvmContainerType->getStructNumElements())
return parser->emitError(attributeLoc, "position out of bounds"),
nullptr;
llvmContainerType = llvmContainerType->getStructElementType(position);
wrappedContainerType =
wrappedContainerType.getStructElementType(position);
} else {
return parser->emitError(typeLoc,
"expected wrapped LLVM IR structure/array type"),
nullptr;
}
}
Builder &builder = parser->getBuilder();
return builder.getType<LLVM::LLVMType>(llvmContainerType);
return wrappedContainerType;
}
// <operation> ::= `llvm.extractvalue` ssa-use
@ -730,8 +690,7 @@ static ParseResult parseCondBrOp(OpAsmParser *parser, OperationState *result) {
Builder &builder = parser->getBuilder();
auto *llvmDialect =
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
auto i1Type = builder.getType<LLVM::LLVMType>(
llvm::Type::getInt1Ty(llvmDialect->getLLVMContext()));
auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect);
if (parser->parseOperand(condition) || parser->parseComma() ||
parser->parseSuccessorAndUseList(trueDest, trueOperands) ||
@ -844,9 +803,26 @@ static ParseResult parseConstantOp(OpAsmParser *parser,
// LLVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
namespace mlir {
namespace LLVM {
namespace detail {
struct LLVMDialectImpl {
LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {}
llvm::LLVMContext llvmContext;
llvm::Module module;
/// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not
/// multi-threaded and requires locked access to prevent race conditions.
llvm::sys::SmartMutex<true> mutex;
};
} // end namespace detail
} // end namespace LLVM
} // end namespace mlir
LLVMDialect::LLVMDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context),
module("LLVMDialectModule", llvmContext) {
impl(new detail::LLVMDialectImpl()) {
addTypes<LLVMType>();
addOperations<
#define GET_OP_LIST
@ -857,13 +833,21 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
allowUnknownOperations();
}
LLVMDialect::~LLVMDialect() {}
#define GET_OP_CLASSES
#include "mlir/LLVMIR/LLVMOps.cpp.inc"
llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; }
llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
/// Parse a type registered to this dialect.
Type LLVMDialect::parseType(StringRef tyData, Location loc) const {
// LLVM is not thread-safe, so lock access to it.
llvm::sys::SmartScopedLock<true> lock(impl->mutex);
llvm::SMDiagnostic errorMessage;
llvm::Type *type = llvm::parseType(tyData, errorMessage, module);
llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module);
if (!type)
return (getContext()->emitError(loc, errorMessage.getMessage()), nullptr);
return LLVMType::get(getContext(), type);
@ -889,3 +873,126 @@ LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function *func,
}
static DialectRegistration<LLVMDialect> llvmDialect;
//===----------------------------------------------------------------------===//
// LLVMType.
//===----------------------------------------------------------------------===//
namespace mlir {
namespace LLVM {
namespace detail {
struct LLVMTypeStorage : public ::mlir::TypeStorage {
LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {}
// LLVM types are pointer-unique.
using KeyTy = llvm::Type *;
bool operator==(const KeyTy &key) const { return key == underlyingType; }
static LLVMTypeStorage *construct(TypeStorageAllocator &allocator,
llvm::Type *ty) {
return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty);
}
llvm::Type *underlyingType;
};
} // end namespace detail
} // end namespace LLVM
} // end namespace mlir
LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
return Base::get(context, FIRST_LLVM_TYPE, llvmType);
}
LLVMDialect &LLVMType::getDialect() {
return static_cast<LLVMDialect &>(Type::getDialect());
}
llvm::Type *LLVMType::getUnderlyingType() const {
return getImpl()->underlyingType;
}
/// Array type utilities.
LLVMType LLVMType::getArrayElementType() {
return get(getContext(), getUnderlyingType()->getArrayElementType());
}
/// Pointer type utilities.
LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
// Lock access to the dialect as this may modify the LLVM context.
llvm::sys::SmartScopedLock<true> lock(getDialect().impl->mutex);
return get(getContext(), getUnderlyingType()->getPointerTo(addrSpace));
}
LLVMType LLVMType::getPointerElementTy() {
return get(getContext(), getUnderlyingType()->getPointerElementType());
}
/// Struct type utilities.
LLVMType LLVMType::getStructElementType(unsigned i) {
return get(getContext(), getUnderlyingType()->getStructElementType(i));
}
/// Utilities used to generate floating point types.
LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
return get(dialect->getContext(),
llvm::Type::getDoubleTy(dialect->getLLVMContext()));
}
LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
return get(dialect->getContext(),
llvm::Type::getFloatTy(dialect->getLLVMContext()));
}
LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
return get(dialect->getContext(),
llvm::Type::getHalfTy(dialect->getLLVMContext()));
}
/// Utilities used to generate integer types.
LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
// Lock access to the dialect as this may modify the LLVM context.
llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
return get(dialect->getContext(),
llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits));
}
/// Utilities used to generate other miscellaneous types.
LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) {
// Lock access to the dialect as this may modify the LLVM context.
llvm::sys::SmartScopedLock<true> lock(elementType.getDialect().impl->mutex);
return get(
elementType.getContext(),
llvm::ArrayType::get(elementType.getUnderlyingType(), numElements));
}
LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
bool isVarArg) {
SmallVector<llvm::Type *, 8> llvmParams;
for (auto param : params)
llvmParams.push_back(param.getUnderlyingType());
// Lock access to the dialect as this may modify the LLVM context.
llvm::sys::SmartScopedLock<true> lock(result.getDialect().impl->mutex);
return get(result.getContext(),
llvm::FunctionType::get(result.getUnderlyingType(), llvmParams,
isVarArg));
}
LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
ArrayRef<LLVMType> elements, bool isPacked) {
SmallVector<llvm::Type *, 8> llvmElements;
for (auto elt : elements)
llvmElements.push_back(elt.getUnderlyingType());
// Lock access to the dialect as this may modify the LLVM context.
llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
return get(
dialect->getContext(),
llvm::StructType::get(dialect->getLLVMContext(), llvmElements, isPacked));
}
LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
// Lock access to the dialect as this may modify the LLVM context.
llvm::sys::SmartScopedLock<true> lock(elementType.getDialect().impl->mutex);
return get(
elementType.getContext(),
llvm::VectorType::get(elementType.getUnderlyingType(), numElements));
}
LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
return get(dialect->getContext(),
llvm::Type::getVoidTy(dialect->getLLVMContext()));
}

View File

@ -45,46 +45,37 @@ llvm::LLVMContext &LLVMLowering::getLLVMContext() {
return module->getContext();
}
// Wrap the given LLVM IR type into an LLVM IR dialect type.
Type LLVMLowering::wrap(llvm::Type *llvmType) {
return LLVM::LLVMType::get(llvmDialect->getContext(), llvmType);
}
// Extract an LLVM IR type from the LLVM IR dialect type.
llvm::Type *LLVMLowering::unwrap(Type type) {
LLVM::LLVMType LLVMLowering::unwrap(Type type) {
if (!type)
return nullptr;
auto *mlirContext = type.getContext();
auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
if (!wrappedLLVMType)
return mlirContext->emitError(UnknownLoc::get(mlirContext),
"conversion resulted in a non-LLVM type"),
nullptr;
return wrappedLLVMType.getUnderlyingType();
mlirContext->emitError(UnknownLoc::get(mlirContext),
"conversion resulted in a non-LLVM type");
return wrappedLLVMType;
}
llvm::IntegerType *LLVMLowering::getIndexType() {
return llvm::IntegerType::get(llvmDialect->getLLVMContext(),
module->getDataLayout().getPointerSizeInBits());
LLVM::LLVMType LLVMLowering::getIndexType() {
return LLVM::LLVMType::getIntNTy(
llvmDialect, module->getDataLayout().getPointerSizeInBits());
}
Type LLVMLowering::convertIndexType(IndexType type) {
return wrap(getIndexType());
}
Type LLVMLowering::convertIndexType(IndexType type) { return getIndexType(); }
Type LLVMLowering::convertIntegerType(IntegerType type) {
return wrap(
llvm::Type::getIntNTy(llvmDialect->getLLVMContext(), type.getWidth()));
return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
}
Type LLVMLowering::convertFloatType(FloatType type) {
switch (type.getKind()) {
case mlir::StandardTypes::F32:
return wrap(llvm::Type::getFloatTy(llvmDialect->getLLVMContext()));
return LLVM::LLVMType::getFloatTy(llvmDialect);
case mlir::StandardTypes::F64:
return wrap(llvm::Type::getDoubleTy(llvmDialect->getLLVMContext()));
return LLVM::LLVMType::getDoubleTy(llvmDialect);
case mlir::StandardTypes::F16:
return wrap(llvm::Type::getHalfTy(llvmDialect->getLLVMContext()));
return LLVM::LLVMType::getHalfTy(llvmDialect);
case mlir::StandardTypes::BF16: {
auto *mlirContext = llvmDialect->getContext();
return mlirContext->emitError(UnknownLoc::get(mlirContext),
@ -102,7 +93,7 @@ Type LLVMLowering::convertFloatType(FloatType type) {
// they are into an LLVM StructType in their order of appearance.
Type LLVMLowering::convertFunctionType(FunctionType type) {
// Convert argument types one by one and check for errors.
SmallVector<llvm::Type *, 8> argTypes;
SmallVector<LLVM::LLVMType, 8> argTypes;
for (auto t : type.getInputs()) {
auto converted = convertType(t);
if (!converted)
@ -113,14 +104,14 @@ Type LLVMLowering::convertFunctionType(FunctionType type) {
// If function does not return anything, create the void result type,
// if it returns on element, convert it, otherwise pack the result types into
// a struct.
llvm::Type *resultType =
LLVM::LLVMType resultType =
type.getNumResults() == 0
? llvm::Type::getVoidTy(llvmDialect->getLLVMContext())
? LLVM::LLVMType::getVoidTy(llvmDialect)
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
return wrap(llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false)
->getPointerTo());
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false)
.getPointerTo();
}
// Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
@ -129,21 +120,21 @@ Type LLVMLowering::convertFunctionType(FunctionType type) {
// pointer to the elemental type of the MemRef and the following N elements are
// values of the Index type, one for each of N dynamic dimensions of the MemRef.
Type LLVMLowering::convertMemRefType(MemRefType type) {
llvm::Type *elementType = unwrap(convertType(type.getElementType()));
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
auto ptrType = elementType->getPointerTo();
auto ptrType = elementType.getPointerTo();
// Extra value for the memory space.
unsigned numDynamicSizes = type.getNumDynamicDims();
// If memref is statically-shaped we return the underlying pointer type.
if (numDynamicSizes == 0) {
return wrap(ptrType);
}
SmallVector<llvm::Type *, 8> types(numDynamicSizes + 1, getIndexType());
if (numDynamicSizes == 0)
return ptrType;
SmallVector<LLVM::LLVMType, 8> types(numDynamicSizes + 1, getIndexType());
types.front() = ptrType;
return wrap(llvm::StructType::get(llvmDialect->getLLVMContext(), types));
return LLVM::LLVMType::getStructTy(llvmDialect, types);
}
// Convert a 1D vector type to an LLVM vector type.
@ -155,9 +146,9 @@ Type LLVMLowering::convertVectorType(VectorType type) {
return {};
}
llvm::Type *elementType = unwrap(convertType(type.getElementType()));
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
return elementType
? wrap(llvm::VectorType::get(elementType, type.getShape().front()))
? LLVM::LLVMType::getVectorTy(elementType, type.getShape().front())
: Type();
}
@ -189,8 +180,7 @@ static Type getMemRefElementPtrType(MemRefType t, LLVMLowering &lowering) {
auto converted = lowering.convertType(elementType);
if (!converted)
return {};
llvm::Type *llvmType = converted.cast<LLVM::LLVMType>().getUnderlyingType();
return LLVM::LLVMType::get(t.getContext(), llvmType->getPointerTo());
return converted.cast<LLVM::LLVMType>().getPointerTo();
}
LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
@ -226,15 +216,13 @@ public:
// Get the MLIR type wrapping the LLVM integer type whose bit width is defined
// by the pointer size used in the LLVM module.
LLVM::LLVMType getIndexType() const {
llvm::Type *llvmType = llvm::Type::getIntNTy(
getContext(), getModule().getDataLayout().getPointerSizeInBits());
return LLVM::LLVMType::get(dialect.getContext(), llvmType);
return LLVM::LLVMType::getIntNTy(
&dialect, getModule().getDataLayout().getPointerSizeInBits());
}
// Get the MLIR type wrapping the LLVM i8* type.
LLVM::LLVMType getVoidPtrType() const {
return LLVM::LLVMType::get(dialect.getContext(),
llvm::Type::getInt8PtrTy(getContext()));
return LLVM::LLVMType::getInt8PtrTy(&dialect);
}
// Create an LLVM IR pseudo-operation defining the given index constant.
@ -478,10 +466,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
cumulativeSize)
.getResult(0);
auto structElementType = lowering.convertType(elementType);
auto elementPtrType = LLVM::LLVMType::get(
op->getContext(), structElementType.cast<LLVM::LLVMType>()
.getUnderlyingType()
->getPointerTo());
auto elementPtrType =
structElementType.cast<LLVM::LLVMType>().getPointerTo();
allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
ArrayRef<Value *>(allocated));
@ -530,14 +516,9 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
}
auto *type =
operands[0]->getType().cast<LLVM::LLVMType>().getUnderlyingType();
auto hasStaticShape = type->isPointerTy();
Type elementPtrType =
(hasStaticShape)
? rewriter.getType<LLVM::LLVMType>(type)
: rewriter.getType<LLVM::LLVMType>(
cast<llvm::StructType>(type)->getStructElementType(0));
auto type = operands[0]->getType().cast<LLVM::LLVMType>();
auto hasStaticShape = type.getUnderlyingType()->isPointerTy();
Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
Value *bufferPtr = extractMemRefElementPtr(
rewriter, op->getLoc(), operands[0], elementPtrType, hasStaticShape);
Value *casted = rewriter.create<LLVM::BitcastOp>(
@ -964,10 +945,6 @@ Type LLVMLowering::convertType(Type t) {
return {};
}
static llvm::Type *unwrapType(Type type) {
return type.cast<LLVM::LLVMType>().getUnderlyingType();
}
// Create an LLVM IR structure type if there is more than one result.
Type LLVMLowering::packFunctionResults(ArrayRef<Type> types) {
assert(!types.empty() && "expected non-empty list of type");
@ -975,18 +952,16 @@ Type LLVMLowering::packFunctionResults(ArrayRef<Type> types) {
if (types.size() == 1)
return convertType(types.front());
SmallVector<llvm::Type *, 8> resultTypes;
SmallVector<LLVM::LLVMType, 8> resultTypes;
resultTypes.reserve(types.size());
for (auto t : types) {
Type converted = convertType(t);
auto converted = convertType(t).dyn_cast<LLVM::LLVMType>();
if (!converted)
return {};
resultTypes.push_back(unwrapType(converted));
resultTypes.push_back(converted);
}
return LLVM::LLVMType::get(
llvmDialect->getContext(),
llvm::StructType::get(llvmDialect->getLLVMContext(), resultTypes));
return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
}
// Convert function signatures using the stored LLVM IR module.

View File

@ -64,12 +64,10 @@ using llvm_select = ValueBuilder<LLVM::SelectOp>;
using icmp = ValueBuilder<LLVM::ICmpOp>;
template <typename T>
static llvm::Type *getPtrToElementType(T containerType,
LLVMLowering &lowering) {
static LLVMType getPtrToElementType(T containerType, LLVMLowering &lowering) {
return lowering.convertType(containerType.getElementType())
.template cast<LLVMType>()
.getUnderlyingType()
->getPointerTo();
.getPointerTo();
}
// Convert the given type to the LLVM IR Dialect type. The following
@ -82,9 +80,8 @@ static llvm::Type *getPtrToElementType(T containerType,
// containing the respective dynamic values.
static Type convertLinalgType(Type t, LLVMLowering &lowering) {
auto *context = t.getContext();
auto *int64Ty = lowering.convertType(IntegerType::get(64, context))
.cast<LLVM::LLVMType>()
.getUnderlyingType();
auto int64Ty = lowering.convertType(IntegerType::get(64, context))
.cast<LLVM::LLVMType>();
// A buffer descriptor contains the pointer to a flat region of storage and
// the size of the region.
@ -95,9 +92,8 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
// int64_t size;
// };
if (auto bufferType = t.dyn_cast<BufferType>()) {
auto *ptrTy = getPtrToElementType(bufferType, lowering);
auto *structTy = llvm::StructType::get(ptrTy, int64Ty);
return LLVMType::get(context, structTy);
auto ptrTy = getPtrToElementType(bufferType, lowering);
return LLVMType::getStructTy(ptrTy, int64Ty);
}
// Range descriptor contains the range bounds and the step as 64-bit integers.
@ -107,10 +103,8 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
// int64_t max;
// int64_t step;
// };
if (t.isa<RangeType>()) {
auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty);
return LLVMType::get(context, structTy);
}
if (t.isa<RangeType>())
return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
// View descriptor contains the pointer to the data buffer, followed by a
// 64-bit integer containing the distance between the beginning of the buffer
@ -136,10 +130,9 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
// int64_t strides[Rank];
// };
if (auto viewType = t.dyn_cast<ViewType>()) {
auto *ptrTy = getPtrToElementType(viewType, lowering);
auto *arrayTy = llvm::ArrayType::get(int64Ty, viewType.getRank());
auto *structTy = llvm::StructType::get(ptrTy, int64Ty, arrayTy, arrayTy);
return LLVMType::get(context, structTy);
auto ptrTy = getPtrToElementType(viewType, lowering);
auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank());
return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy);
}
return Type();
@ -165,9 +158,8 @@ public:
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto indexType = IndexType::get(op->getContext());
auto voidPtrTy = LLVM::LLVMType::get(
op->getContext(),
llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
auto int64Ty = lowering.convertType(operands[0]->getType());
// Insert the `malloc` declaration if it is not already present.
auto *module = op->getFunction()->getModule();
@ -187,8 +179,8 @@ public:
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
else
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
auto elementPtrType = rewriter.getType<LLVMType>(getPtrToElementType(
allocOp.getResult()->getType().cast<BufferType>(), lowering));
auto elementPtrType = getPtrToElementType(
allocOp.getResult()->getType().cast<BufferType>(), lowering);
auto bufferDescriptorType =
convertLinalgType(allocOp.getResult()->getType(), lowering);
@ -221,9 +213,8 @@ public:
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto voidPtrTy = LLVM::LLVMType::get(
op->getContext(),
llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
// Insert the `free` declaration if it is not already present.
auto *module = op->getFunction()->getModule();
Function *freeFunc = module->getNamedFunction("free");
@ -235,8 +226,8 @@ public:
// Get MLIR types for extracting element pointer.
auto deallocOp = cast<BufferDeallocOp>(op);
auto elementPtrTy = rewriter.getType<LLVMType>(getPtrToElementType(
deallocOp.getOperand()->getType().cast<BufferType>(), lowering));
auto elementPtrTy = getPtrToElementType(
deallocOp.getOperand()->getType().cast<BufferType>(), lowering);
// Emit MLIR for buffer_dealloc.
edsc::ScopedContext context(rewriter, op->getLoc());
@ -298,8 +289,7 @@ public:
ArrayRef<Value *> indices,
PatternRewriter &rewriter) const {
auto loadOp = cast<Op>(op);
auto elementTy = rewriter.getType<LLVMType>(
getPtrToElementType(loadOp.getViewType(), lowering));
auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering);
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
auto pos = [&rewriter](ArrayRef<int> values) {
return positionAttr(rewriter, values);
@ -425,8 +415,7 @@ public:
// Helper function to obtain the ptr of the given `view`.
auto getViewPtr = [pos, &rewriter, this](ViewType type,
Value *view) -> Value * {
auto elementPtrTy =
rewriter.getType<LLVMType>(getPtrToElementType(type, lowering));
auto elementPtrTy = getPtrToElementType(type, lowering);
return extractvalue(elementPtrTy, view, pos(0));
};
@ -512,8 +501,7 @@ public:
PatternRewriter &rewriter) const override {
auto viewOp = cast<ViewOp>(op);
auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
auto elementTy = rewriter.getType<LLVMType>(
getPtrToElementType(viewOp.getViewType(), lowering));
auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
auto pos = [&rewriter](ArrayRef<int> values) {