forked from OSchip/llvm-project
Add thread-safe utilities to LLVMType to allow constructing llvm types in a multi-threaded environment. The LLVMContext is not thread-safe and directly constructing a raw llvm::Type can create situations where the LLVMContext is modified by multiple threads at the same time.
-- PiperOrigin-RevId: 249526233
This commit is contained in:
parent
c33862b0ed
commit
5953d12b95
|
@ -62,22 +62,14 @@ Type linalg::convertLinalgType(Type t) {
|
|||
// Simple conversions.
|
||||
if (t.isa<IndexType>()) {
|
||||
int width = dialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
|
||||
auto *integerTy = llvm::IntegerType::get(dialect->getLLVMContext(), width);
|
||||
return LLVM::LLVMType::get(context, integerTy);
|
||||
}
|
||||
if (auto intTy = t.dyn_cast<IntegerType>()) {
|
||||
int width = intTy.getWidth();
|
||||
auto *integerTy = llvm::IntegerType::get(dialect->getLLVMContext(), width);
|
||||
return LLVM::LLVMType::get(context, integerTy);
|
||||
}
|
||||
if (t.isF32()) {
|
||||
auto *floatTy = llvm::Type::getFloatTy(dialect->getLLVMContext());
|
||||
return LLVM::LLVMType::get(context, floatTy);
|
||||
}
|
||||
if (t.isF64()) {
|
||||
auto *doubleTy = llvm::Type::getDoubleTy(dialect->getLLVMContext());
|
||||
return LLVM::LLVMType::get(context, doubleTy);
|
||||
return LLVM::LLVMType::getIntNTy(dialect, width);
|
||||
}
|
||||
if (auto intTy = t.dyn_cast<IntegerType>())
|
||||
return LLVM::LLVMType::getIntNTy(dialect, intTy.getWidth());
|
||||
if (t.isF32())
|
||||
return LLVM::LLVMType::getFloatTy(dialect);
|
||||
if (t.isF64())
|
||||
return LLVM::LLVMType::getDoubleTy(dialect);
|
||||
|
||||
// Range descriptor contains the range bounds and the step as 64-bit integers.
|
||||
//
|
||||
|
@ -87,9 +79,8 @@ Type linalg::convertLinalgType(Type t) {
|
|||
// int64_t step;
|
||||
// };
|
||||
if (auto rangeTy = t.dyn_cast<linalg::RangeType>()) {
|
||||
auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext());
|
||||
auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty);
|
||||
return LLVM::LLVMType::get(context, structTy);
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect);
|
||||
return LLVM::LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
|
||||
}
|
||||
|
||||
// View descriptor contains the pointer to the data buffer, followed by a
|
||||
|
@ -116,14 +107,12 @@ Type linalg::convertLinalgType(Type t) {
|
|||
// int64_t strides[Rank];
|
||||
// };
|
||||
if (auto viewTy = t.dyn_cast<linalg::ViewType>()) {
|
||||
auto *elemTy = linalg::convertLinalgType(viewTy.getElementType())
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getUnderlyingType()
|
||||
->getPointerTo();
|
||||
auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext());
|
||||
auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank());
|
||||
auto *structTy = llvm::StructType::get(elemTy, int64Ty, arrayTy, arrayTy);
|
||||
return LLVM::LLVMType::get(context, structTy);
|
||||
auto elemTy = linalg::convertLinalgType(viewTy.getElementType())
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getPointerTo();
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect);
|
||||
auto arrayTy = LLVM::LLVMType::getArrayTy(int64Ty, viewTy.getRank());
|
||||
return LLVM::LLVMType::getStructTy(elemTy, int64Ty, arrayTy, arrayTy);
|
||||
}
|
||||
|
||||
// All other types are kept as is.
|
||||
|
@ -217,11 +206,9 @@ public:
|
|||
if (type.hasStaticShape())
|
||||
return memref;
|
||||
|
||||
auto elementTy = LLVM::LLVMType::get(
|
||||
type.getContext(), linalg::convertLinalgType(type.getElementType())
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getUnderlyingType()
|
||||
->getPointerTo());
|
||||
auto elementTy = linalg::convertLinalgType(type.getElementType())
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getPointerTo();
|
||||
return intrinsics::extractvalue(elementTy, memref, pos(0));
|
||||
};
|
||||
|
||||
|
@ -307,11 +294,9 @@ public:
|
|||
auto sliceOp = cast<linalg::SliceOp>(op);
|
||||
auto newViewDescriptorType =
|
||||
linalg::convertLinalgType(sliceOp.getViewType());
|
||||
auto elementType = rewriter.getType<LLVM::LLVMType>(
|
||||
linalg::convertLinalgType(sliceOp.getElementType())
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getUnderlyingType()
|
||||
->getPointerTo());
|
||||
auto elementType = linalg::convertLinalgType(sliceOp.getElementType())
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getPointerTo();
|
||||
auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64));
|
||||
|
||||
auto pos = [&rewriter](ArrayRef<int> values) {
|
||||
|
|
|
@ -67,11 +67,9 @@ public:
|
|||
auto loadOp = cast<Op>(op);
|
||||
auto elementType =
|
||||
loadOp.getViewType().template cast<linalg::ViewType>().getElementType();
|
||||
auto *llvmPtrType = linalg::convertLinalgType(elementType)
|
||||
.template cast<LLVM::LLVMType>()
|
||||
.getUnderlyingType()
|
||||
->getPointerTo();
|
||||
elementType = rewriter.getType<LLVM::LLVMType>(llvmPtrType);
|
||||
elementType = linalg::convertLinalgType(elementType)
|
||||
.template cast<LLVM::LLVMType>()
|
||||
.getPointerTo();
|
||||
auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64));
|
||||
|
||||
auto pos = [&rewriter](ArrayRef<int> values) {
|
||||
|
|
|
@ -210,15 +210,11 @@ private:
|
|||
|
||||
// Create a function declaration for printf, signature is `i32 (i8*, ...)`
|
||||
Builder builder(&module);
|
||||
MLIRContext *context = module.getContext();
|
||||
auto *llvmDialect =
|
||||
auto *dialect =
|
||||
module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
auto &llvmModule = llvmDialect->getLLVMModule();
|
||||
llvm::IRBuilder<> llvmBuilder(llvmModule.getContext());
|
||||
|
||||
auto llvmI32Ty = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(32));
|
||||
auto llvmI8PtrTy =
|
||||
LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(8)->getPointerTo());
|
||||
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
|
||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
|
||||
auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
|
||||
printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy);
|
||||
// It should be variadic, but we don't support it fully just yet.
|
||||
|
|
|
@ -41,10 +41,12 @@ class LLVMContext;
|
|||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
class LLVMDialect;
|
||||
|
||||
namespace detail {
|
||||
struct LLVMTypeStorage;
|
||||
}
|
||||
struct LLVMDialectImpl;
|
||||
} // namespace detail
|
||||
|
||||
class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
|
||||
detail::LLVMTypeStorage> {
|
||||
|
@ -57,9 +59,72 @@ public:
|
|||
|
||||
static bool kindof(unsigned kind) { return kind == LLVM_TYPE; }
|
||||
|
||||
static LLVMType get(MLIRContext *context, llvm::Type *llvmType);
|
||||
|
||||
LLVMDialect &getDialect();
|
||||
llvm::Type *getUnderlyingType() const;
|
||||
|
||||
/// Array type utilities.
|
||||
LLVMType getArrayElementType();
|
||||
|
||||
/// Pointer type utilities.
|
||||
LLVMType getPointerTo(unsigned addrSpace = 0);
|
||||
LLVMType getPointerElementTy();
|
||||
|
||||
/// Struct type utilities.
|
||||
LLVMType getStructElementType(unsigned i);
|
||||
|
||||
/// Utilities used to generate floating point types.
|
||||
static LLVMType getDoubleTy(LLVMDialect *dialect);
|
||||
static LLVMType getFloatTy(LLVMDialect *dialect);
|
||||
static LLVMType getHalfTy(LLVMDialect *dialect);
|
||||
|
||||
/// Utilities used to generate integer types.
|
||||
static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits);
|
||||
static LLVMType getInt1Ty(LLVMDialect *dialect) {
|
||||
return getIntNTy(dialect, /*numBits=*/1);
|
||||
}
|
||||
static LLVMType getInt8Ty(LLVMDialect *dialect) {
|
||||
return getIntNTy(dialect, /*numBits=*/8);
|
||||
}
|
||||
static LLVMType getInt8PtrTy(LLVMDialect *dialect) {
|
||||
return getInt8Ty(dialect).getPointerTo();
|
||||
}
|
||||
static LLVMType getInt16Ty(LLVMDialect *dialect) {
|
||||
return getIntNTy(dialect, /*numBits=*/16);
|
||||
}
|
||||
static LLVMType getInt32Ty(LLVMDialect *dialect) {
|
||||
return getIntNTy(dialect, /*numBits=*/32);
|
||||
}
|
||||
static LLVMType getInt64Ty(LLVMDialect *dialect) {
|
||||
return getIntNTy(dialect, /*numBits=*/64);
|
||||
}
|
||||
|
||||
/// Utilities used to generate other miscellaneous types.
|
||||
static LLVMType getArrayTy(LLVMType elementType, uint64_t numElements);
|
||||
static LLVMType getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
|
||||
bool isVarArg);
|
||||
static LLVMType getFunctionTy(LLVMType result, bool isVarArg) {
|
||||
return getFunctionTy(result, llvm::None, isVarArg);
|
||||
}
|
||||
static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef<LLVMType> elements,
|
||||
bool isPacked = false);
|
||||
static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) {
|
||||
return getStructTy(dialect, llvm::None, isPacked);
|
||||
}
|
||||
template <typename... Args>
|
||||
static typename std::enable_if<llvm::are_base_of<LLVMType, Args...>::value,
|
||||
LLVMType>::type
|
||||
getStructTy(LLVMType elt1, Args... elts) {
|
||||
SmallVector<LLVMType, 8> fields({elt1, elts...});
|
||||
return getStructTy(&elt1.getDialect(), fields);
|
||||
}
|
||||
static LLVMType getVectorTy(LLVMType elementType, unsigned numElements);
|
||||
static LLVMType getVoidTy(LLVMDialect *dialect);
|
||||
|
||||
private:
|
||||
friend LLVMDialect;
|
||||
|
||||
/// Get an LLVM type with a pre-existing llvm type.
|
||||
static LLVMType get(MLIRContext *context, llvm::Type *llvmType);
|
||||
};
|
||||
|
||||
///// Ops /////
|
||||
|
@ -69,10 +134,11 @@ public:
|
|||
class LLVMDialect : public Dialect {
|
||||
public:
|
||||
explicit LLVMDialect(MLIRContext *context);
|
||||
~LLVMDialect();
|
||||
static StringRef getDialectNamespace() { return "llvm"; }
|
||||
|
||||
llvm::LLVMContext &getLLVMContext() { return llvmContext; }
|
||||
llvm::Module &getLLVMModule() { return module; }
|
||||
llvm::LLVMContext &getLLVMContext();
|
||||
llvm::Module &getLLVMModule();
|
||||
|
||||
/// Parse a type registered to this dialect.
|
||||
Type parseType(StringRef tyData, Location loc) const override;
|
||||
|
@ -86,8 +152,9 @@ public:
|
|||
NamedAttribute argAttr) override;
|
||||
|
||||
private:
|
||||
llvm::LLVMContext llvmContext;
|
||||
llvm::Module module;
|
||||
friend LLVMType;
|
||||
|
||||
std::unique_ptr<detail::LLVMDialectImpl> impl;
|
||||
};
|
||||
|
||||
} // end namespace LLVM
|
||||
|
|
|
@ -31,11 +31,12 @@ class IntegerType;
|
|||
class LLVMContext;
|
||||
class Module;
|
||||
class Type;
|
||||
}
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
class LLVMDialect;
|
||||
class LLVMType;
|
||||
}
|
||||
|
||||
/// Conversion from the Standard dialect to the LLVM IR dialect. Provides hooks
|
||||
|
@ -55,6 +56,9 @@ public:
|
|||
/// Returns the LLVM context.
|
||||
llvm::LLVMContext &getLLVMContext();
|
||||
|
||||
/// Returns the LLVM dialect.
|
||||
LLVM::LLVMDialect *getDialect() { return llvmDialect; }
|
||||
|
||||
protected:
|
||||
/// Add a set of converters to the given pattern list. Store the module
|
||||
/// associated with the dialect for further type conversion.
|
||||
|
@ -119,13 +123,13 @@ private:
|
|||
|
||||
// Get the LLVM representation of the index type based on the bitwidth of the
|
||||
// pointer as defined by the data layout of the module.
|
||||
llvm::IntegerType *getIndexType();
|
||||
LLVM::LLVMType getIndexType();
|
||||
|
||||
// Wrap the given LLVM IR type into an LLVM IR dialect type.
|
||||
Type wrap(llvm::Type *llvmType);
|
||||
|
||||
// Extract an LLVM IR type from the LLVM IR dialect type.
|
||||
llvm::Type *unwrap(Type type);
|
||||
// Extract an LLVM IR dialect type.
|
||||
LLVM::LLVMType unwrap(Type type);
|
||||
};
|
||||
|
||||
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
|
||||
|
|
|
@ -29,40 +29,12 @@
|
|||
#include "llvm/IR/Attributes.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
#include "llvm/IR/Type.h"
|
||||
#include "llvm/Support/Mutex.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::LLVM;
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
namespace detail {
|
||||
struct LLVMTypeStorage : public ::mlir::TypeStorage {
|
||||
LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {}
|
||||
|
||||
// LLVM types are pointer-unique.
|
||||
using KeyTy = llvm::Type *;
|
||||
bool operator==(const KeyTy &key) const { return key == underlyingType; }
|
||||
|
||||
static LLVMTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
llvm::Type *ty) {
|
||||
return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty);
|
||||
}
|
||||
|
||||
llvm::Type *underlyingType;
|
||||
};
|
||||
} // end namespace detail
|
||||
} // end namespace LLVM
|
||||
} // end namespace mlir
|
||||
|
||||
LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
|
||||
return Base::get(context, FIRST_LLVM_TYPE, llvmType);
|
||||
}
|
||||
|
||||
llvm::Type *LLVMType::getUnderlyingType() const {
|
||||
return getImpl()->underlyingType;
|
||||
}
|
||||
|
||||
static void printLLVMBinaryOp(OpAsmPrinter *p, Operation *op) {
|
||||
// Fallback to the generic form if the op is not well-formed (may happen
|
||||
// during incomplete rewrites, and used for debugging).
|
||||
|
@ -161,14 +133,13 @@ static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) {
|
|||
// The result type is either i1 or a vector type <? x i1> if the inputs are
|
||||
// vectors.
|
||||
auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
|
||||
llvm::Type *llvmResultType = llvm::Type::getInt1Ty(dialect->getLLVMContext());
|
||||
auto resultType = LLVMType::getInt1Ty(dialect);
|
||||
auto argType = type.dyn_cast<LLVM::LLVMType>();
|
||||
if (!argType)
|
||||
return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type");
|
||||
if (argType.getUnderlyingType()->isVectorTy())
|
||||
llvmResultType = llvm::VectorType::get(
|
||||
llvmResultType, argType.getUnderlyingType()->getVectorNumElements());
|
||||
auto resultType = builder.getType<LLVM::LLVMType>(llvmResultType);
|
||||
resultType = LLVMType::getVectorTy(
|
||||
resultType, argType.getUnderlyingType()->getVectorNumElements());
|
||||
|
||||
result->attributes = attrs;
|
||||
result->addTypes({resultType});
|
||||
|
@ -180,9 +151,7 @@ static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) {
|
||||
auto *llvmPtrTy = op.getType().cast<LLVM::LLVMType>().getUnderlyingType();
|
||||
auto *llvmElemTy = llvm::cast<llvm::PointerType>(llvmPtrTy)->getElementType();
|
||||
auto elemTy = LLVM::LLVMType::get(op.getContext(), llvmElemTy);
|
||||
auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
|
||||
|
||||
auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()},
|
||||
op.getContext());
|
||||
|
@ -291,13 +260,10 @@ static Type getLoadStoreElementType(OpAsmParser *parser, Type type,
|
|||
if (!llvmTy)
|
||||
return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
|
||||
nullptr;
|
||||
auto *llvmPtrTy = dyn_cast<llvm::PointerType>(llvmTy.getUnderlyingType());
|
||||
if (!llvmPtrTy)
|
||||
if (!llvmTy.getUnderlyingType()->isPointerTy())
|
||||
return parser->emitError(trailingTypeLoc, "expected LLVM pointer type"),
|
||||
nullptr;
|
||||
auto elemTy = LLVM::LLVMType::get(parser->getBuilder().getContext(),
|
||||
llvmPtrTy->getElementType());
|
||||
return elemTy;
|
||||
return llvmTy.getPointerElementTy();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
|
||||
|
@ -465,33 +431,28 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
|
|||
Builder &builder = parser->getBuilder();
|
||||
auto *llvmDialect =
|
||||
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
llvm::Type *llvmResultType;
|
||||
Type wrappedResultType;
|
||||
LLVM::LLVMType llvmResultType;
|
||||
if (funcType.getNumResults() == 0) {
|
||||
llvmResultType = llvm::Type::getVoidTy(llvmDialect->getLLVMContext());
|
||||
wrappedResultType = builder.getType<LLVM::LLVMType>(llvmResultType);
|
||||
llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
|
||||
} else {
|
||||
wrappedResultType = funcType.getResult(0);
|
||||
auto wrappedLLVMResultType = wrappedResultType.dyn_cast<LLVM::LLVMType>();
|
||||
if (!wrappedLLVMResultType)
|
||||
llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
|
||||
if (!llvmResultType)
|
||||
return parser->emitError(trailingTypeLoc,
|
||||
"expected result to have LLVM type");
|
||||
llvmResultType = wrappedLLVMResultType.getUnderlyingType();
|
||||
}
|
||||
|
||||
SmallVector<llvm::Type *, 8> argTypes;
|
||||
SmallVector<LLVM::LLVMType, 8> argTypes;
|
||||
argTypes.reserve(funcType.getNumInputs());
|
||||
for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
|
||||
auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>();
|
||||
if (!argType)
|
||||
return parser->emitError(trailingTypeLoc,
|
||||
"expected LLVM types as inputs");
|
||||
argTypes.push_back(argType.getUnderlyingType());
|
||||
argTypes.push_back(argType);
|
||||
}
|
||||
auto *llvmFuncType = llvm::FunctionType::get(llvmResultType, argTypes,
|
||||
/*isVarArg=*/false);
|
||||
auto wrappedFuncType =
|
||||
builder.getType<LLVM::LLVMType>(llvmFuncType->getPointerTo());
|
||||
auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
|
||||
/*isVarArg=*/false);
|
||||
auto wrappedFuncType = llvmFuncType.getPointerTo();
|
||||
|
||||
auto funcArguments =
|
||||
ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
|
||||
|
@ -505,7 +466,7 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
|
|||
parser->getNameLoc(), result->operands))
|
||||
return failure();
|
||||
|
||||
result->addTypes(wrappedResultType);
|
||||
result->addTypes(llvmResultType);
|
||||
}
|
||||
|
||||
result->attributes = attrs;
|
||||
|
@ -544,7 +505,6 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser,
|
|||
// type by taking the element type, indexed by the position attribute for
|
||||
// stuctures. Check the position index before accessing, it is supposed to be
|
||||
// in bounds.
|
||||
llvm::Type *llvmContainerType = wrappedContainerType.getUnderlyingType();
|
||||
for (Attribute subAttr : positionArrayAttr) {
|
||||
auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
|
||||
if (!positionElementAttr)
|
||||
|
@ -552,27 +512,27 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser,
|
|||
"expected an array of integer literals"),
|
||||
nullptr;
|
||||
int position = positionElementAttr.getInt();
|
||||
auto *llvmContainerType = wrappedContainerType.getUnderlyingType();
|
||||
if (llvmContainerType->isArrayTy()) {
|
||||
if (position < 0 || static_cast<unsigned>(position) >=
|
||||
llvmContainerType->getArrayNumElements())
|
||||
return parser->emitError(attributeLoc, "position out of bounds"),
|
||||
nullptr;
|
||||
llvmContainerType = llvmContainerType->getArrayElementType();
|
||||
wrappedContainerType = wrappedContainerType.getArrayElementType();
|
||||
} else if (llvmContainerType->isStructTy()) {
|
||||
if (position < 0 || static_cast<unsigned>(position) >=
|
||||
llvmContainerType->getStructNumElements())
|
||||
return parser->emitError(attributeLoc, "position out of bounds"),
|
||||
nullptr;
|
||||
llvmContainerType = llvmContainerType->getStructElementType(position);
|
||||
wrappedContainerType =
|
||||
wrappedContainerType.getStructElementType(position);
|
||||
} else {
|
||||
return parser->emitError(typeLoc,
|
||||
"expected wrapped LLVM IR structure/array type"),
|
||||
nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Builder &builder = parser->getBuilder();
|
||||
return builder.getType<LLVM::LLVMType>(llvmContainerType);
|
||||
return wrappedContainerType;
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.extractvalue` ssa-use
|
||||
|
@ -730,8 +690,7 @@ static ParseResult parseCondBrOp(OpAsmParser *parser, OperationState *result) {
|
|||
Builder &builder = parser->getBuilder();
|
||||
auto *llvmDialect =
|
||||
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
auto i1Type = builder.getType<LLVM::LLVMType>(
|
||||
llvm::Type::getInt1Ty(llvmDialect->getLLVMContext()));
|
||||
auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect);
|
||||
|
||||
if (parser->parseOperand(condition) || parser->parseComma() ||
|
||||
parser->parseSuccessorAndUseList(trueDest, trueOperands) ||
|
||||
|
@ -844,9 +803,26 @@ static ParseResult parseConstantOp(OpAsmParser *parser,
|
|||
// LLVMDialect initialization, type parsing, and registration.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
namespace detail {
|
||||
struct LLVMDialectImpl {
|
||||
LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {}
|
||||
|
||||
llvm::LLVMContext llvmContext;
|
||||
llvm::Module module;
|
||||
|
||||
/// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not
|
||||
/// multi-threaded and requires locked access to prevent race conditions.
|
||||
llvm::sys::SmartMutex<true> mutex;
|
||||
};
|
||||
} // end namespace detail
|
||||
} // end namespace LLVM
|
||||
} // end namespace mlir
|
||||
|
||||
LLVMDialect::LLVMDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context),
|
||||
module("LLVMDialectModule", llvmContext) {
|
||||
impl(new detail::LLVMDialectImpl()) {
|
||||
addTypes<LLVMType>();
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
|
@ -857,13 +833,21 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
|
|||
allowUnknownOperations();
|
||||
}
|
||||
|
||||
LLVMDialect::~LLVMDialect() {}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/LLVMIR/LLVMOps.cpp.inc"
|
||||
|
||||
llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; }
|
||||
llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
|
||||
|
||||
/// Parse a type registered to this dialect.
|
||||
Type LLVMDialect::parseType(StringRef tyData, Location loc) const {
|
||||
// LLVM is not thread-safe, so lock access to it.
|
||||
llvm::sys::SmartScopedLock<true> lock(impl->mutex);
|
||||
|
||||
llvm::SMDiagnostic errorMessage;
|
||||
llvm::Type *type = llvm::parseType(tyData, errorMessage, module);
|
||||
llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module);
|
||||
if (!type)
|
||||
return (getContext()->emitError(loc, errorMessage.getMessage()), nullptr);
|
||||
return LLVMType::get(getContext(), type);
|
||||
|
@ -889,3 +873,126 @@ LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function *func,
|
|||
}
|
||||
|
||||
static DialectRegistration<LLVMDialect> llvmDialect;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LLVMType.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
namespace detail {
|
||||
struct LLVMTypeStorage : public ::mlir::TypeStorage {
|
||||
LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {}
|
||||
|
||||
// LLVM types are pointer-unique.
|
||||
using KeyTy = llvm::Type *;
|
||||
bool operator==(const KeyTy &key) const { return key == underlyingType; }
|
||||
|
||||
static LLVMTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
llvm::Type *ty) {
|
||||
return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty);
|
||||
}
|
||||
|
||||
llvm::Type *underlyingType;
|
||||
};
|
||||
} // end namespace detail
|
||||
} // end namespace LLVM
|
||||
} // end namespace mlir
|
||||
|
||||
LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
|
||||
return Base::get(context, FIRST_LLVM_TYPE, llvmType);
|
||||
}
|
||||
|
||||
LLVMDialect &LLVMType::getDialect() {
|
||||
return static_cast<LLVMDialect &>(Type::getDialect());
|
||||
}
|
||||
|
||||
llvm::Type *LLVMType::getUnderlyingType() const {
|
||||
return getImpl()->underlyingType;
|
||||
}
|
||||
|
||||
/// Array type utilities.
|
||||
LLVMType LLVMType::getArrayElementType() {
|
||||
return get(getContext(), getUnderlyingType()->getArrayElementType());
|
||||
}
|
||||
|
||||
/// Pointer type utilities.
|
||||
LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
|
||||
// Lock access to the dialect as this may modify the LLVM context.
|
||||
llvm::sys::SmartScopedLock<true> lock(getDialect().impl->mutex);
|
||||
return get(getContext(), getUnderlyingType()->getPointerTo(addrSpace));
|
||||
}
|
||||
LLVMType LLVMType::getPointerElementTy() {
|
||||
return get(getContext(), getUnderlyingType()->getPointerElementType());
|
||||
}
|
||||
|
||||
/// Struct type utilities.
|
||||
LLVMType LLVMType::getStructElementType(unsigned i) {
|
||||
return get(getContext(), getUnderlyingType()->getStructElementType(i));
|
||||
}
|
||||
|
||||
/// Utilities used to generate floating point types.
|
||||
LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
|
||||
return get(dialect->getContext(),
|
||||
llvm::Type::getDoubleTy(dialect->getLLVMContext()));
|
||||
}
|
||||
LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
|
||||
return get(dialect->getContext(),
|
||||
llvm::Type::getFloatTy(dialect->getLLVMContext()));
|
||||
}
|
||||
LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
|
||||
return get(dialect->getContext(),
|
||||
llvm::Type::getHalfTy(dialect->getLLVMContext()));
|
||||
}
|
||||
|
||||
/// Utilities used to generate integer types.
|
||||
LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
|
||||
// Lock access to the dialect as this may modify the LLVM context.
|
||||
llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
|
||||
return get(dialect->getContext(),
|
||||
llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits));
|
||||
}
|
||||
|
||||
/// Utilities used to generate other miscellaneous types.
|
||||
LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) {
|
||||
// Lock access to the dialect as this may modify the LLVM context.
|
||||
llvm::sys::SmartScopedLock<true> lock(elementType.getDialect().impl->mutex);
|
||||
return get(
|
||||
elementType.getContext(),
|
||||
llvm::ArrayType::get(elementType.getUnderlyingType(), numElements));
|
||||
}
|
||||
LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
|
||||
bool isVarArg) {
|
||||
SmallVector<llvm::Type *, 8> llvmParams;
|
||||
for (auto param : params)
|
||||
llvmParams.push_back(param.getUnderlyingType());
|
||||
|
||||
// Lock access to the dialect as this may modify the LLVM context.
|
||||
llvm::sys::SmartScopedLock<true> lock(result.getDialect().impl->mutex);
|
||||
return get(result.getContext(),
|
||||
llvm::FunctionType::get(result.getUnderlyingType(), llvmParams,
|
||||
isVarArg));
|
||||
}
|
||||
LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
|
||||
ArrayRef<LLVMType> elements, bool isPacked) {
|
||||
SmallVector<llvm::Type *, 8> llvmElements;
|
||||
for (auto elt : elements)
|
||||
llvmElements.push_back(elt.getUnderlyingType());
|
||||
|
||||
// Lock access to the dialect as this may modify the LLVM context.
|
||||
llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
|
||||
return get(
|
||||
dialect->getContext(),
|
||||
llvm::StructType::get(dialect->getLLVMContext(), llvmElements, isPacked));
|
||||
}
|
||||
LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
|
||||
// Lock access to the dialect as this may modify the LLVM context.
|
||||
llvm::sys::SmartScopedLock<true> lock(elementType.getDialect().impl->mutex);
|
||||
return get(
|
||||
elementType.getContext(),
|
||||
llvm::VectorType::get(elementType.getUnderlyingType(), numElements));
|
||||
}
|
||||
LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
|
||||
return get(dialect->getContext(),
|
||||
llvm::Type::getVoidTy(dialect->getLLVMContext()));
|
||||
}
|
||||
|
|
|
@ -45,46 +45,37 @@ llvm::LLVMContext &LLVMLowering::getLLVMContext() {
|
|||
return module->getContext();
|
||||
}
|
||||
|
||||
// Wrap the given LLVM IR type into an LLVM IR dialect type.
|
||||
Type LLVMLowering::wrap(llvm::Type *llvmType) {
|
||||
return LLVM::LLVMType::get(llvmDialect->getContext(), llvmType);
|
||||
}
|
||||
|
||||
// Extract an LLVM IR type from the LLVM IR dialect type.
|
||||
llvm::Type *LLVMLowering::unwrap(Type type) {
|
||||
LLVM::LLVMType LLVMLowering::unwrap(Type type) {
|
||||
if (!type)
|
||||
return nullptr;
|
||||
auto *mlirContext = type.getContext();
|
||||
auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
|
||||
if (!wrappedLLVMType)
|
||||
return mlirContext->emitError(UnknownLoc::get(mlirContext),
|
||||
"conversion resulted in a non-LLVM type"),
|
||||
nullptr;
|
||||
return wrappedLLVMType.getUnderlyingType();
|
||||
mlirContext->emitError(UnknownLoc::get(mlirContext),
|
||||
"conversion resulted in a non-LLVM type");
|
||||
return wrappedLLVMType;
|
||||
}
|
||||
|
||||
llvm::IntegerType *LLVMLowering::getIndexType() {
|
||||
return llvm::IntegerType::get(llvmDialect->getLLVMContext(),
|
||||
module->getDataLayout().getPointerSizeInBits());
|
||||
LLVM::LLVMType LLVMLowering::getIndexType() {
|
||||
return LLVM::LLVMType::getIntNTy(
|
||||
llvmDialect, module->getDataLayout().getPointerSizeInBits());
|
||||
}
|
||||
|
||||
Type LLVMLowering::convertIndexType(IndexType type) {
|
||||
return wrap(getIndexType());
|
||||
}
|
||||
Type LLVMLowering::convertIndexType(IndexType type) { return getIndexType(); }
|
||||
|
||||
Type LLVMLowering::convertIntegerType(IntegerType type) {
|
||||
return wrap(
|
||||
llvm::Type::getIntNTy(llvmDialect->getLLVMContext(), type.getWidth()));
|
||||
return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
|
||||
}
|
||||
|
||||
Type LLVMLowering::convertFloatType(FloatType type) {
|
||||
switch (type.getKind()) {
|
||||
case mlir::StandardTypes::F32:
|
||||
return wrap(llvm::Type::getFloatTy(llvmDialect->getLLVMContext()));
|
||||
return LLVM::LLVMType::getFloatTy(llvmDialect);
|
||||
case mlir::StandardTypes::F64:
|
||||
return wrap(llvm::Type::getDoubleTy(llvmDialect->getLLVMContext()));
|
||||
return LLVM::LLVMType::getDoubleTy(llvmDialect);
|
||||
case mlir::StandardTypes::F16:
|
||||
return wrap(llvm::Type::getHalfTy(llvmDialect->getLLVMContext()));
|
||||
return LLVM::LLVMType::getHalfTy(llvmDialect);
|
||||
case mlir::StandardTypes::BF16: {
|
||||
auto *mlirContext = llvmDialect->getContext();
|
||||
return mlirContext->emitError(UnknownLoc::get(mlirContext),
|
||||
|
@ -102,7 +93,7 @@ Type LLVMLowering::convertFloatType(FloatType type) {
|
|||
// they are into an LLVM StructType in their order of appearance.
|
||||
Type LLVMLowering::convertFunctionType(FunctionType type) {
|
||||
// Convert argument types one by one and check for errors.
|
||||
SmallVector<llvm::Type *, 8> argTypes;
|
||||
SmallVector<LLVM::LLVMType, 8> argTypes;
|
||||
for (auto t : type.getInputs()) {
|
||||
auto converted = convertType(t);
|
||||
if (!converted)
|
||||
|
@ -113,14 +104,14 @@ Type LLVMLowering::convertFunctionType(FunctionType type) {
|
|||
// If function does not return anything, create the void result type,
|
||||
// if it returns on element, convert it, otherwise pack the result types into
|
||||
// a struct.
|
||||
llvm::Type *resultType =
|
||||
LLVM::LLVMType resultType =
|
||||
type.getNumResults() == 0
|
||||
? llvm::Type::getVoidTy(llvmDialect->getLLVMContext())
|
||||
? LLVM::LLVMType::getVoidTy(llvmDialect)
|
||||
: unwrap(packFunctionResults(type.getResults()));
|
||||
if (!resultType)
|
||||
return {};
|
||||
return wrap(llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false)
|
||||
->getPointerTo());
|
||||
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false)
|
||||
.getPointerTo();
|
||||
}
|
||||
|
||||
// Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
|
||||
|
@ -129,21 +120,21 @@ Type LLVMLowering::convertFunctionType(FunctionType type) {
|
|||
// pointer to the elemental type of the MemRef and the following N elements are
|
||||
// values of the Index type, one for each of N dynamic dimensions of the MemRef.
|
||||
Type LLVMLowering::convertMemRefType(MemRefType type) {
|
||||
llvm::Type *elementType = unwrap(convertType(type.getElementType()));
|
||||
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
|
||||
if (!elementType)
|
||||
return {};
|
||||
auto ptrType = elementType->getPointerTo();
|
||||
auto ptrType = elementType.getPointerTo();
|
||||
|
||||
// Extra value for the memory space.
|
||||
unsigned numDynamicSizes = type.getNumDynamicDims();
|
||||
// If memref is statically-shaped we return the underlying pointer type.
|
||||
if (numDynamicSizes == 0) {
|
||||
return wrap(ptrType);
|
||||
}
|
||||
SmallVector<llvm::Type *, 8> types(numDynamicSizes + 1, getIndexType());
|
||||
if (numDynamicSizes == 0)
|
||||
return ptrType;
|
||||
|
||||
SmallVector<LLVM::LLVMType, 8> types(numDynamicSizes + 1, getIndexType());
|
||||
types.front() = ptrType;
|
||||
|
||||
return wrap(llvm::StructType::get(llvmDialect->getLLVMContext(), types));
|
||||
return LLVM::LLVMType::getStructTy(llvmDialect, types);
|
||||
}
|
||||
|
||||
// Convert a 1D vector type to an LLVM vector type.
|
||||
|
@ -155,9 +146,9 @@ Type LLVMLowering::convertVectorType(VectorType type) {
|
|||
return {};
|
||||
}
|
||||
|
||||
llvm::Type *elementType = unwrap(convertType(type.getElementType()));
|
||||
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
|
||||
return elementType
|
||||
? wrap(llvm::VectorType::get(elementType, type.getShape().front()))
|
||||
? LLVM::LLVMType::getVectorTy(elementType, type.getShape().front())
|
||||
: Type();
|
||||
}
|
||||
|
||||
|
@ -189,8 +180,7 @@ static Type getMemRefElementPtrType(MemRefType t, LLVMLowering &lowering) {
|
|||
auto converted = lowering.convertType(elementType);
|
||||
if (!converted)
|
||||
return {};
|
||||
llvm::Type *llvmType = converted.cast<LLVM::LLVMType>().getUnderlyingType();
|
||||
return LLVM::LLVMType::get(t.getContext(), llvmType->getPointerTo());
|
||||
return converted.cast<LLVM::LLVMType>().getPointerTo();
|
||||
}
|
||||
|
||||
LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
|
||||
|
@ -226,15 +216,13 @@ public:
|
|||
// Get the MLIR type wrapping the LLVM integer type whose bit width is defined
|
||||
// by the pointer size used in the LLVM module.
|
||||
LLVM::LLVMType getIndexType() const {
|
||||
llvm::Type *llvmType = llvm::Type::getIntNTy(
|
||||
getContext(), getModule().getDataLayout().getPointerSizeInBits());
|
||||
return LLVM::LLVMType::get(dialect.getContext(), llvmType);
|
||||
return LLVM::LLVMType::getIntNTy(
|
||||
&dialect, getModule().getDataLayout().getPointerSizeInBits());
|
||||
}
|
||||
|
||||
// Get the MLIR type wrapping the LLVM i8* type.
|
||||
LLVM::LLVMType getVoidPtrType() const {
|
||||
return LLVM::LLVMType::get(dialect.getContext(),
|
||||
llvm::Type::getInt8PtrTy(getContext()));
|
||||
return LLVM::LLVMType::getInt8PtrTy(&dialect);
|
||||
}
|
||||
|
||||
// Create an LLVM IR pseudo-operation defining the given index constant.
|
||||
|
@ -478,10 +466,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
cumulativeSize)
|
||||
.getResult(0);
|
||||
auto structElementType = lowering.convertType(elementType);
|
||||
auto elementPtrType = LLVM::LLVMType::get(
|
||||
op->getContext(), structElementType.cast<LLVM::LLVMType>()
|
||||
.getUnderlyingType()
|
||||
->getPointerTo());
|
||||
auto elementPtrType =
|
||||
structElementType.cast<LLVM::LLVMType>().getPointerTo();
|
||||
allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
|
||||
ArrayRef<Value *>(allocated));
|
||||
|
||||
|
@ -530,14 +516,9 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||
op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
|
||||
}
|
||||
|
||||
auto *type =
|
||||
operands[0]->getType().cast<LLVM::LLVMType>().getUnderlyingType();
|
||||
auto hasStaticShape = type->isPointerTy();
|
||||
Type elementPtrType =
|
||||
(hasStaticShape)
|
||||
? rewriter.getType<LLVM::LLVMType>(type)
|
||||
: rewriter.getType<LLVM::LLVMType>(
|
||||
cast<llvm::StructType>(type)->getStructElementType(0));
|
||||
auto type = operands[0]->getType().cast<LLVM::LLVMType>();
|
||||
auto hasStaticShape = type.getUnderlyingType()->isPointerTy();
|
||||
Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
|
||||
Value *bufferPtr = extractMemRefElementPtr(
|
||||
rewriter, op->getLoc(), operands[0], elementPtrType, hasStaticShape);
|
||||
Value *casted = rewriter.create<LLVM::BitcastOp>(
|
||||
|
@ -964,10 +945,6 @@ Type LLVMLowering::convertType(Type t) {
|
|||
return {};
|
||||
}
|
||||
|
||||
static llvm::Type *unwrapType(Type type) {
|
||||
return type.cast<LLVM::LLVMType>().getUnderlyingType();
|
||||
}
|
||||
|
||||
// Create an LLVM IR structure type if there is more than one result.
|
||||
Type LLVMLowering::packFunctionResults(ArrayRef<Type> types) {
|
||||
assert(!types.empty() && "expected non-empty list of type");
|
||||
|
@ -975,18 +952,16 @@ Type LLVMLowering::packFunctionResults(ArrayRef<Type> types) {
|
|||
if (types.size() == 1)
|
||||
return convertType(types.front());
|
||||
|
||||
SmallVector<llvm::Type *, 8> resultTypes;
|
||||
SmallVector<LLVM::LLVMType, 8> resultTypes;
|
||||
resultTypes.reserve(types.size());
|
||||
for (auto t : types) {
|
||||
Type converted = convertType(t);
|
||||
auto converted = convertType(t).dyn_cast<LLVM::LLVMType>();
|
||||
if (!converted)
|
||||
return {};
|
||||
resultTypes.push_back(unwrapType(converted));
|
||||
resultTypes.push_back(converted);
|
||||
}
|
||||
|
||||
return LLVM::LLVMType::get(
|
||||
llvmDialect->getContext(),
|
||||
llvm::StructType::get(llvmDialect->getLLVMContext(), resultTypes));
|
||||
return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
|
||||
}
|
||||
|
||||
// Convert function signatures using the stored LLVM IR module.
|
||||
|
|
|
@ -64,12 +64,10 @@ using llvm_select = ValueBuilder<LLVM::SelectOp>;
|
|||
using icmp = ValueBuilder<LLVM::ICmpOp>;
|
||||
|
||||
template <typename T>
|
||||
static llvm::Type *getPtrToElementType(T containerType,
|
||||
LLVMLowering &lowering) {
|
||||
static LLVMType getPtrToElementType(T containerType, LLVMLowering &lowering) {
|
||||
return lowering.convertType(containerType.getElementType())
|
||||
.template cast<LLVMType>()
|
||||
.getUnderlyingType()
|
||||
->getPointerTo();
|
||||
.getPointerTo();
|
||||
}
|
||||
|
||||
// Convert the given type to the LLVM IR Dialect type. The following
|
||||
|
@ -82,9 +80,8 @@ static llvm::Type *getPtrToElementType(T containerType,
|
|||
// containing the respective dynamic values.
|
||||
static Type convertLinalgType(Type t, LLVMLowering &lowering) {
|
||||
auto *context = t.getContext();
|
||||
auto *int64Ty = lowering.convertType(IntegerType::get(64, context))
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getUnderlyingType();
|
||||
auto int64Ty = lowering.convertType(IntegerType::get(64, context))
|
||||
.cast<LLVM::LLVMType>();
|
||||
|
||||
// A buffer descriptor contains the pointer to a flat region of storage and
|
||||
// the size of the region.
|
||||
|
@ -95,9 +92,8 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
|
|||
// int64_t size;
|
||||
// };
|
||||
if (auto bufferType = t.dyn_cast<BufferType>()) {
|
||||
auto *ptrTy = getPtrToElementType(bufferType, lowering);
|
||||
auto *structTy = llvm::StructType::get(ptrTy, int64Ty);
|
||||
return LLVMType::get(context, structTy);
|
||||
auto ptrTy = getPtrToElementType(bufferType, lowering);
|
||||
return LLVMType::getStructTy(ptrTy, int64Ty);
|
||||
}
|
||||
|
||||
// Range descriptor contains the range bounds and the step as 64-bit integers.
|
||||
|
@ -107,10 +103,8 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
|
|||
// int64_t max;
|
||||
// int64_t step;
|
||||
// };
|
||||
if (t.isa<RangeType>()) {
|
||||
auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty);
|
||||
return LLVMType::get(context, structTy);
|
||||
}
|
||||
if (t.isa<RangeType>())
|
||||
return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
|
||||
|
||||
// View descriptor contains the pointer to the data buffer, followed by a
|
||||
// 64-bit integer containing the distance between the beginning of the buffer
|
||||
|
@ -136,10 +130,9 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
|
|||
// int64_t strides[Rank];
|
||||
// };
|
||||
if (auto viewType = t.dyn_cast<ViewType>()) {
|
||||
auto *ptrTy = getPtrToElementType(viewType, lowering);
|
||||
auto *arrayTy = llvm::ArrayType::get(int64Ty, viewType.getRank());
|
||||
auto *structTy = llvm::StructType::get(ptrTy, int64Ty, arrayTy, arrayTy);
|
||||
return LLVMType::get(context, structTy);
|
||||
auto ptrTy = getPtrToElementType(viewType, lowering);
|
||||
auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank());
|
||||
return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy);
|
||||
}
|
||||
|
||||
return Type();
|
||||
|
@ -165,9 +158,8 @@ public:
|
|||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto indexType = IndexType::get(op->getContext());
|
||||
auto voidPtrTy = LLVM::LLVMType::get(
|
||||
op->getContext(),
|
||||
llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
|
||||
auto voidPtrTy =
|
||||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
auto int64Ty = lowering.convertType(operands[0]->getType());
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
auto *module = op->getFunction()->getModule();
|
||||
|
@ -187,8 +179,8 @@ public:
|
|||
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
|
||||
else
|
||||
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
|
||||
auto elementPtrType = rewriter.getType<LLVMType>(getPtrToElementType(
|
||||
allocOp.getResult()->getType().cast<BufferType>(), lowering));
|
||||
auto elementPtrType = getPtrToElementType(
|
||||
allocOp.getResult()->getType().cast<BufferType>(), lowering);
|
||||
auto bufferDescriptorType =
|
||||
convertLinalgType(allocOp.getResult()->getType(), lowering);
|
||||
|
||||
|
@ -221,9 +213,8 @@ public:
|
|||
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto voidPtrTy = LLVM::LLVMType::get(
|
||||
op->getContext(),
|
||||
llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
|
||||
auto voidPtrTy =
|
||||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
auto *module = op->getFunction()->getModule();
|
||||
Function *freeFunc = module->getNamedFunction("free");
|
||||
|
@ -235,8 +226,8 @@ public:
|
|||
|
||||
// Get MLIR types for extracting element pointer.
|
||||
auto deallocOp = cast<BufferDeallocOp>(op);
|
||||
auto elementPtrTy = rewriter.getType<LLVMType>(getPtrToElementType(
|
||||
deallocOp.getOperand()->getType().cast<BufferType>(), lowering));
|
||||
auto elementPtrTy = getPtrToElementType(
|
||||
deallocOp.getOperand()->getType().cast<BufferType>(), lowering);
|
||||
|
||||
// Emit MLIR for buffer_dealloc.
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
|
@ -298,8 +289,7 @@ public:
|
|||
ArrayRef<Value *> indices,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto loadOp = cast<Op>(op);
|
||||
auto elementTy = rewriter.getType<LLVMType>(
|
||||
getPtrToElementType(loadOp.getViewType(), lowering));
|
||||
auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering);
|
||||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
||||
auto pos = [&rewriter](ArrayRef<int> values) {
|
||||
return positionAttr(rewriter, values);
|
||||
|
@ -425,8 +415,7 @@ public:
|
|||
// Helper function to obtain the ptr of the given `view`.
|
||||
auto getViewPtr = [pos, &rewriter, this](ViewType type,
|
||||
Value *view) -> Value * {
|
||||
auto elementPtrTy =
|
||||
rewriter.getType<LLVMType>(getPtrToElementType(type, lowering));
|
||||
auto elementPtrTy = getPtrToElementType(type, lowering);
|
||||
return extractvalue(elementPtrTy, view, pos(0));
|
||||
};
|
||||
|
||||
|
@ -512,8 +501,7 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
auto viewOp = cast<ViewOp>(op);
|
||||
auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
|
||||
auto elementTy = rewriter.getType<LLVMType>(
|
||||
getPtrToElementType(viewOp.getViewType(), lowering));
|
||||
auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
|
||||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
||||
|
||||
auto pos = [&rewriter](ArrayRef<int> values) {
|
||||
|
|
Loading…
Reference in New Issue