[mlir] NFC: Expose `getElementPtrType` and `getSizes` methods of AllocOpLowering.

Differential Revision: https://reviews.llvm.org/D84917
This commit is contained in:
Alexander Belyaev 2020-07-30 20:16:51 +02:00
parent 7551fd5ef8
commit 6b8c641d8e
4 changed files with 83 additions and 56 deletions

View File

@ -440,6 +440,20 @@ public:
ValueRange indices, ConversionPatternRewriter &rewriter,
llvm::Module &module) const;
/// Returns the type of a pointer to an element of the memref.
Type getElementPtrType(MemRefType type) const;
/// Determines sizes to be used in the memref descriptor.
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
ArrayRef<Value> dynSizes,
ConversionPatternRewriter &rewriter,
SmallVectorImpl<Value> &sizes) const;
/// Computes total size in bytes of to store the given shape.
Value getCumulativeSizeInBytes(Location loc, Type elementType,
ArrayRef<Value> shape,
ConversionPatternRewriter &rewriter) const;
protected:
/// Reference to the type converter, with potential extensions.
LLVMTypeConverter &typeConverter;

View File

@ -924,6 +924,52 @@ Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
offset, rewriter);
}
Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
auto elementType = type.getElementType();
auto structElementType = typeConverter.convertType(elementType);
return structElementType.cast<LLVM::LLVMType>().getPointerTo(
type.getMemorySpace());
}
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
Location loc, MemRefType memRefType, ArrayRef<Value> dynSizes,
ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes) const {
sizes.reserve(memRefType.getRank());
unsigned i = 0;
for (int64_t s : memRefType.getShape())
sizes.push_back(s == ShapedType::kDynamicSize
? dynSizes[i++]
: createIndexConstant(rewriter, loc, s));
}
Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
Location loc, Type elementType, ArrayRef<Value> sizes,
ConversionPatternRewriter &rewriter) const {
// Compute the total number of memref elements.
Value cumulativeSizeInBytes =
sizes.empty() ? createIndexConstant(rewriter, loc, 1) : sizes.front();
for (unsigned i = 1, e = sizes.size(); i < e; ++i)
cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, sizes[i]});
// Compute the size of an individual element. This emits the MLIR equivalent
// of the following sizeof(...) implementation in LLVM IR:
// %0 = getelementptr %elementType* null, %indexType 1
// %1 = ptrtoint %elementType* %0 to %indexType
// which is a common pattern of getting the size of a type in bytes.
auto convertedPtrType = typeConverter.convertType(elementType)
.cast<LLVM::LLVMType>()
.getPointerTo();
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
auto gep = rewriter.create<LLVM::GEPOp>(
loc, convertedPtrType,
ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
auto elementSize =
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
return rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, elementSize});
}
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
@ -1693,7 +1739,7 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
Location loc, ConversionPatternRewriter &rewriter, MemRefType memRefType,
Value allocatedTypePtr, Value allocatedBytePtr, Value accessAlignment,
uint64_t offset, ArrayRef<int64_t> strides, ArrayRef<Value> sizes) const {
auto elementPtrType = getElementPtrType(memRefType);
auto elementPtrType = this->getElementPtrType(memRefType);
auto structType = typeConverter.convertType(memRefType);
auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
@ -1751,52 +1797,6 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
return memRefDescriptor;
}
/// Determines sizes to be used in the memref descriptor.
void getSizes(Location loc, MemRefType memRefType, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter,
SmallVectorImpl<Value> &sizes, Value &cumulativeSize,
Value &one) const {
sizes.reserve(memRefType.getRank());
unsigned i = 0;
for (int64_t s : memRefType.getShape())
sizes.push_back(s == -1 ? operands[i++]
: createIndexConstant(rewriter, loc, s));
if (sizes.empty())
sizes.push_back(createIndexConstant(rewriter, loc, 1));
// Compute the total number of memref elements.
cumulativeSize = sizes.front();
for (unsigned i = 1, e = sizes.size(); i < e; ++i)
cumulativeSize = rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value>{cumulativeSize, sizes[i]});
// Compute the size of an individual element. This emits the MLIR equivalent
// of the following sizeof(...) implementation in LLVM IR:
// %0 = getelementptr %elementType* null, %indexType 1
// %1 = ptrtoint %elementType* %0 to %indexType
// which is a common pattern of getting the size of a type in bytes.
auto elementType = memRefType.getElementType();
auto convertedPtrType = typeConverter.convertType(elementType)
.template cast<LLVM::LLVMType>()
.getPointerTo();
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
one = createIndexConstant(rewriter, loc, 1);
auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType,
ArrayRef<Value>{nullPtr, one});
auto elementSize =
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
cumulativeSize = rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value>{cumulativeSize, elementSize});
}
/// Returns the type of a pointer to an element of the memref.
Type getElementPtrType(MemRefType memRefType) const {
auto elementType = memRefType.getElementType();
auto structElementType = typeConverter.convertType(elementType);
return structElementType.template cast<LLVM::LLVMType>().getPointerTo(
memRefType.getMemorySpace());
}
/// Returns the memref's element size in bytes.
// TODO: there are other places where this is used. Expose publicly?
static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
@ -1851,7 +1851,7 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
MemRefType memRefType, Value one, Value &accessAlignment,
Value &allocatedBytePtr,
ConversionPatternRewriter &rewriter) const {
auto elementPtrType = getElementPtrType(memRefType);
auto elementPtrType = this->getElementPtrType(memRefType);
// With alloca, one gets a pointer to the element type right away.
// For stack allocations.
@ -1954,9 +1954,10 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
// values and dynamic sizes are passed to 'alloc' as operands. In case of
// zero-dimensional memref, assume a scalar (size 1).
SmallVector<Value, 4> sizes;
Value cumulativeSize, one;
getSizes(loc, memRefType, operands, rewriter, sizes, cumulativeSize, one);
this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes);
Value cumulativeSize = this->getCumulativeSizeInBytes(
loc, memRefType.getElementType(), sizes, rewriter);
// Allocate the underlying buffer.
// Value holding the alignment that has to be performed post allocation
// (in conjunction with allocators that do not support alignment, eg.
@ -1965,8 +1966,9 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
// Byte pointer to the allocated buffer.
Value allocatedBytePtr;
Value allocatedTypePtr =
allocateBuffer(loc, cumulativeSize, op, memRefType, one,
accessAlignment, allocatedBytePtr, rewriter);
allocateBuffer(loc, cumulativeSize, op, memRefType,
createIndexConstant(rewriter, loc, 1), accessAlignment,
allocatedBytePtr, rewriter);
int64_t offset;
SmallVector<int64_t, 4> strides;

View File

@ -36,6 +36,7 @@ func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64
// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
@ -76,6 +77,7 @@ func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64
// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@ -105,6 +107,7 @@ func @dynamic_alloca(%arg0: index, %arg1: index) -> memref<?x?xf32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[sz_bytes]] x !llvm.float : (!llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@ -150,6 +153,7 @@ func @stdlib_aligned_alloc(%N : index) -> memref<32x18xf32> {
// ALIGNED-ALLOC-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// ALIGNED-ALLOC-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// ALIGNED-ALLOC-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
// ALIGNED-ALLOC-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// ALIGNED-ALLOC-NEXT: %[[alignment:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64
// ALIGNED-ALLOC-NEXT: %[[allocated:.*]] = llvm.call @aligned_alloc(%[[alignment]], %[[bytes]]) : (!llvm.i64, !llvm.i64) -> !llvm<"i8*">
// ALIGNED-ALLOC-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">

View File

@ -74,6 +74,7 @@ func @zero_d_alloc() -> memref<f32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
@ -88,6 +89,7 @@ func @zero_d_alloc() -> memref<f32> {
// BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// BAREPTR-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
// BAREPTR-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// BAREPTR-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*">
// BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
// BAREPTR-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
@ -126,9 +128,10 @@ func @aligned_1d_alloc() -> memref<42xf32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
// CHECK-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64
// CHECK-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one]] : !llvm.i64
// CHECK-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one_1]] : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
@ -137,7 +140,7 @@ func @aligned_1d_alloc() -> memref<42xf32> {
// CHECK-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64
// CHECK-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64
// CHECK-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64
// CHECK-NEXT: %[[aligned:.*]] = llvm.getelementptr %9[%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: %[[aligned:.*]] = llvm.getelementptr %[[allocated]][%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm<"i8*"> to !llvm<"float*">
// CHECK-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
@ -149,9 +152,10 @@ func @aligned_1d_alloc() -> memref<42xf32> {
// BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// BAREPTR-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
// BAREPTR-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// BAREPTR-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
// BAREPTR-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64
// BAREPTR-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one]] : !llvm.i64
// BAREPTR-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one_1]] : !llvm.i64
// BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm<"i8*">
// BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
// BAREPTR-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
@ -160,7 +164,7 @@ func @aligned_1d_alloc() -> memref<42xf32> {
// BAREPTR-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64
// BAREPTR-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64
// BAREPTR-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64
// BAREPTR-NEXT: %[[aligned:.*]] = llvm.getelementptr %9[%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
// BAREPTR-NEXT: %[[aligned:.*]] = llvm.getelementptr %[[allocated]][%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
// BAREPTR-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm<"i8*"> to !llvm<"float*">
// BAREPTR-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
// BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
@ -182,6 +186,7 @@ func @static_alloc() -> memref<32x18xf32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
@ -193,6 +198,7 @@ func @static_alloc() -> memref<32x18xf32> {
// BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// BAREPTR-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
// BAREPTR-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*">
// BAREPTR-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
%0 = alloc() : memref<32x18xf32>
@ -211,6 +217,7 @@ func @static_alloca() -> memref<32x18xf32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[bytes]] x !llvm.float : (!llvm.i64) -> !llvm<"float*">
%0 = alloca() : memref<32x18xf32>