forked from OSchip/llvm-project
[MLIR] Refactor memref type -> LLVM Type conversion
- Eliminate duplicated information about mapping from memref -> its descriptor fields by consolidating that mapping in two functions: getMemRefDescriptorFields and getUnrankedMemRefDescriptorFields. - Change convertMemRefType() and convertUnrankedMemRefType() to use these functions. - Remove convertMemrefSignature and convertUnrankedMemrefSignature. Differential Revision: https://reviews.llvm.org/D90707
This commit is contained in:
parent
63e72aa4f5
commit
8c2025cc61
|
@ -164,12 +164,20 @@ private:
|
||||||
/// Convert a memref type into an LLVM type that captures the relevant data.
|
/// Convert a memref type into an LLVM type that captures the relevant data.
|
||||||
Type convertMemRefType(MemRefType type);
|
Type convertMemRefType(MemRefType type);
|
||||||
|
|
||||||
/// Convert a memref type into a list of non-aggregate LLVM IR types that
|
/// Convert a memref type into a list of LLVM IR types that will form the
|
||||||
/// contain all the relevant data. In particular, the list will contain:
|
/// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
|
||||||
|
/// arrays in the descriptors are unpacked to individual index-typed elements,
|
||||||
|
/// else they are are kept as rank-sized arrays of index type. In particular,
|
||||||
|
/// the list will contain:
|
||||||
/// - two pointers to the memref element type, followed by
|
/// - two pointers to the memref element type, followed by
|
||||||
/// - an integer offset, followed by
|
/// - an index-typed offset, followed by
|
||||||
/// - one integer size per dimension of the memref, followed by
|
/// - (if unpackAggregates = true)
|
||||||
/// - one integer stride per dimension of the memref.
|
/// - one index-typed size per dimension of the memref, followed by
|
||||||
|
/// - one index-typed stride per dimension of the memref.
|
||||||
|
/// - (if unpackArrregates = false)
|
||||||
|
/// - one rank-sized array of index-type for the size of each dimension
|
||||||
|
/// - one rank-sized array of index-type for the stride of each dimension
|
||||||
|
///
|
||||||
/// For example, memref<?x?xf32> is converted to the following list:
|
/// For example, memref<?x?xf32> is converted to the following list:
|
||||||
/// - `!llvm<"float*">` (allocated pointer),
|
/// - `!llvm<"float*">` (allocated pointer),
|
||||||
/// - `!llvm<"float*">` (aligned pointer),
|
/// - `!llvm<"float*">` (aligned pointer),
|
||||||
|
@ -177,17 +185,19 @@ private:
|
||||||
/// - `!llvm.i64`, `!llvm.i64` (sizes),
|
/// - `!llvm.i64`, `!llvm.i64` (sizes),
|
||||||
/// - `!llvm.i64`, `!llvm.i64` (strides).
|
/// - `!llvm.i64`, `!llvm.i64` (strides).
|
||||||
/// These types can be recomposed to a memref descriptor struct.
|
/// These types can be recomposed to a memref descriptor struct.
|
||||||
SmallVector<Type, 5> convertMemRefSignature(MemRefType type);
|
SmallVector<LLVM::LLVMType, 5>
|
||||||
|
getMemRefDescriptorFields(MemRefType type, bool unpackAggregates);
|
||||||
|
|
||||||
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
|
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
|
||||||
/// that contain all the relevant data. In particular, this list contains:
|
/// that will form the unranked memref descriptor. In particular, this list
|
||||||
|
/// contains:
|
||||||
/// - an integer rank, followed by
|
/// - an integer rank, followed by
|
||||||
/// - a pointer to the memref descriptor struct.
|
/// - a pointer to the memref descriptor struct.
|
||||||
/// For example, memref<*xf32> is converted to the following list:
|
/// For example, memref<*xf32> is converted to the following list:
|
||||||
/// !llvm.i64 (rank)
|
/// !llvm.i64 (rank)
|
||||||
/// !llvm<"i8*"> (type-erased pointer).
|
/// !llvm<"i8*"> (type-erased pointer).
|
||||||
/// These types can be recomposed to a unranked memref descriptor struct.
|
/// These types can be recomposed to a unranked memref descriptor struct.
|
||||||
SmallVector<Type, 2> convertUnrankedMemRefSignature();
|
SmallVector<LLVM::LLVMType, 2> getUnrankedMemRefDescriptorFields();
|
||||||
|
|
||||||
// Convert an unranked memref type to an LLVM type that captures the
|
// Convert an unranked memref type to an LLVM type that captures the
|
||||||
// runtime rank and a pointer to the static ranked memref desc
|
// runtime rank and a pointer to the static ranked memref desc
|
||||||
|
|
|
@ -61,14 +61,17 @@ LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
|
||||||
Type type,
|
Type type,
|
||||||
SmallVectorImpl<Type> &result) {
|
SmallVectorImpl<Type> &result) {
|
||||||
if (auto memref = type.dyn_cast<MemRefType>()) {
|
if (auto memref = type.dyn_cast<MemRefType>()) {
|
||||||
auto converted = converter.convertMemRefSignature(memref);
|
// In signatures, Memref descriptors are expanded into lists of
|
||||||
|
// non-aggregate values.
|
||||||
|
auto converted =
|
||||||
|
converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
|
||||||
if (converted.empty())
|
if (converted.empty())
|
||||||
return failure();
|
return failure();
|
||||||
result.append(converted.begin(), converted.end());
|
result.append(converted.begin(), converted.end());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
if (type.isa<UnrankedMemRefType>()) {
|
if (type.isa<UnrankedMemRefType>()) {
|
||||||
auto converted = converter.convertUnrankedMemRefSignature();
|
auto converted = converter.getUnrankedMemRefDescriptorFields();
|
||||||
if (converted.empty())
|
if (converted.empty())
|
||||||
return failure();
|
return failure();
|
||||||
result.append(converted.begin(), converted.end());
|
result.append(converted.begin(), converted.end());
|
||||||
|
@ -216,32 +219,6 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
|
||||||
return converted.getPointerTo();
|
return converted.getPointerTo();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// In signatures, MemRef descriptors are expanded into lists of non-aggregate
|
|
||||||
/// values.
|
|
||||||
SmallVector<Type, 5>
|
|
||||||
LLVMTypeConverter::convertMemRefSignature(MemRefType type) {
|
|
||||||
SmallVector<Type, 5> results;
|
|
||||||
assert(isStrided(type) &&
|
|
||||||
"Non-strided layout maps must have been normalized away");
|
|
||||||
|
|
||||||
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
|
|
||||||
if (!elementType)
|
|
||||||
return {};
|
|
||||||
auto indexTy = getIndexType();
|
|
||||||
|
|
||||||
results.insert(results.begin(), 2,
|
|
||||||
elementType.getPointerTo(type.getMemorySpace()));
|
|
||||||
results.push_back(indexTy);
|
|
||||||
auto rank = type.getRank();
|
|
||||||
results.insert(results.end(), 2 * rank, indexTy);
|
|
||||||
return results;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// In signatures, unranked MemRef descriptors are expanded into a pair "rank,
|
|
||||||
/// pointer to descriptor".
|
|
||||||
SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
|
|
||||||
return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Function types are converted to LLVM Function types by recursively converting
|
// Function types are converted to LLVM Function types by recursively converting
|
||||||
// argument and result types. If MLIR Function has zero results, the LLVM
|
// argument and result types. If MLIR Function has zero results, the LLVM
|
||||||
|
@ -305,69 +282,92 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
|
||||||
return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
|
return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
|
|
||||||
// contains:
|
|
||||||
// 1. the pointer to the data buffer, followed by
|
|
||||||
// 2. a lowered `index`-type integer containing the distance between the
|
|
||||||
// beginning of the buffer and the first element to be accessed through the
|
|
||||||
// view, followed by
|
|
||||||
// 3. an array containing as many `index`-type integers as the rank of the
|
|
||||||
// MemRef: the array represents the size, in number of elements, of the memref
|
|
||||||
// along the given dimension. For constant MemRef dimensions, the
|
|
||||||
// corresponding size entry is a constant whose runtime value must match the
|
|
||||||
// static value, followed by
|
|
||||||
// 4. a second array containing as many `index`-type integers as the rank of
|
|
||||||
// the MemRef: the second array represents the "stride" (in tensor abstraction
|
|
||||||
// sense), i.e. the number of consecutive elements of the underlying buffer.
|
|
||||||
// TODO: add assertions for the static cases.
|
|
||||||
//
|
|
||||||
// template <typename Elem, size_t Rank>
|
|
||||||
// struct {
|
|
||||||
// Elem *allocatedPtr;
|
|
||||||
// Elem *alignedPtr;
|
|
||||||
// int64_t offset;
|
|
||||||
// int64_t sizes[Rank]; // omitted when rank == 0
|
|
||||||
// int64_t strides[Rank]; // omitted when rank == 0
|
|
||||||
// };
|
|
||||||
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
|
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
|
||||||
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1;
|
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1;
|
||||||
static constexpr unsigned kOffsetPosInMemRefDescriptor = 2;
|
static constexpr unsigned kOffsetPosInMemRefDescriptor = 2;
|
||||||
static constexpr unsigned kSizePosInMemRefDescriptor = 3;
|
static constexpr unsigned kSizePosInMemRefDescriptor = 3;
|
||||||
static constexpr unsigned kStridePosInMemRefDescriptor = 4;
|
static constexpr unsigned kStridePosInMemRefDescriptor = 4;
|
||||||
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
|
|
||||||
int64_t offset;
|
/// Convert a memref type into a list of LLVM IR types that will form the
|
||||||
SmallVector<int64_t, 4> strides;
|
/// memref descriptor. The result contains the following types:
|
||||||
bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset));
|
/// 1. The pointer to the allocated data buffer, followed by
|
||||||
assert(strideSuccess &&
|
/// 2. The pointer to the aligned data buffer, followed by
|
||||||
|
/// 3. A lowered `index`-type integer containing the distance between the
|
||||||
|
/// beginning of the buffer and the first element to be accessed through the
|
||||||
|
/// view, followed by
|
||||||
|
/// 4. An array containing as many `index`-type integers as the rank of the
|
||||||
|
/// MemRef: the array represents the size, in number of elements, of the memref
|
||||||
|
/// along the given dimension. For constant MemRef dimensions, the
|
||||||
|
/// corresponding size entry is a constant whose runtime value must match the
|
||||||
|
/// static value, followed by
|
||||||
|
/// 5. A second array containing as many `index`-type integers as the rank of
|
||||||
|
/// the MemRef: the second array represents the "stride" (in tensor abstraction
|
||||||
|
/// sense), i.e. the number of consecutive elements of the underlying buffer.
|
||||||
|
/// TODO: add assertions for the static cases.
|
||||||
|
///
|
||||||
|
/// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
|
||||||
|
/// are expanded into individual index-type elements.
|
||||||
|
///
|
||||||
|
/// template <typename Elem, typename Index, size_t Rank>
|
||||||
|
/// struct {
|
||||||
|
/// Elem *allocatedPtr;
|
||||||
|
/// Elem *alignedPtr;
|
||||||
|
/// Index offset;
|
||||||
|
/// Index sizes[Rank]; // omitted when rank == 0
|
||||||
|
/// Index strides[Rank]; // omitted when rank == 0
|
||||||
|
/// };
|
||||||
|
SmallVector<LLVM::LLVMType, 5>
|
||||||
|
LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
|
||||||
|
bool unpackAggregates) {
|
||||||
|
assert(isStrided(type) &&
|
||||||
"Non-strided layout maps must have been normalized away");
|
"Non-strided layout maps must have been normalized away");
|
||||||
(void)strideSuccess;
|
|
||||||
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
|
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
|
||||||
if (!elementType)
|
if (!elementType)
|
||||||
return {};
|
return {};
|
||||||
auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
|
auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
|
||||||
auto indexTy = getIndexType();
|
auto indexTy = getIndexType();
|
||||||
|
|
||||||
|
SmallVector<LLVM::LLVMType, 5> results = {ptrTy, ptrTy, indexTy};
|
||||||
auto rank = type.getRank();
|
auto rank = type.getRank();
|
||||||
if (rank > 0) {
|
if (rank == 0)
|
||||||
auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank());
|
return results;
|
||||||
return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy, arrayTy, arrayTy);
|
|
||||||
}
|
if (unpackAggregates)
|
||||||
return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy);
|
results.insert(results.end(), 2 * rank, indexTy);
|
||||||
|
else
|
||||||
|
results.insert(results.end(), 2, LLVM::LLVMType::getArrayTy(indexTy, rank));
|
||||||
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts UnrankedMemRefType to LLVMType. The result is a descriptor which
|
/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
|
||||||
// contains:
|
/// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
|
||||||
// 1. int64_t rank, the dynamic rank of this MemRef
|
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
|
||||||
// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
|
// When converting a MemRefType to a struct with descriptor fields, do not
|
||||||
// stack allocated (alloca) copy of a MemRef descriptor that got casted to
|
// unpack the `sizes` and `strides` arrays.
|
||||||
// be unranked.
|
SmallVector<LLVM::LLVMType, 5> types =
|
||||||
|
getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
|
||||||
|
return LLVM::LLVMType::getStructTy(&getContext(), types);
|
||||||
|
}
|
||||||
|
|
||||||
static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
|
static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
|
||||||
static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
|
static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
|
||||||
|
|
||||||
|
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
|
||||||
|
/// that will form the unranked memref descriptor. In particular, the fields
|
||||||
|
/// for an unranked memref descriptor are:
|
||||||
|
/// 1. index-typed rank, the dynamic rank of this MemRef
|
||||||
|
/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
|
||||||
|
/// stack allocated (alloca) copy of a MemRef descriptor that got casted to
|
||||||
|
/// be unranked.
|
||||||
|
SmallVector<LLVM::LLVMType, 2>
|
||||||
|
LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
|
||||||
|
return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())};
|
||||||
|
}
|
||||||
|
|
||||||
Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
|
Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
|
||||||
auto rankTy = getIndexType();
|
return LLVM::LLVMType::getStructTy(&getContext(),
|
||||||
auto ptrTy = LLVM::LLVMType::getInt8PtrTy(&getContext());
|
getUnrankedMemRefDescriptorFields());
|
||||||
return LLVM::LLVMType::getStructTy(rankTy, ptrTy);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert a memref type to a bare pointer to the memref element type.
|
/// Convert a memref type to a bare pointer to the memref element type.
|
||||||
|
|
Loading…
Reference in New Issue