[mlir] ConvertStandardToLLVM: make AllocLikeOpLowering public

It is useful for someone who wants to implement custom AllocOp LLVM lowering

Differential Revision: https://reviews.llvm.org/D102932
This commit is contained in:
Butygin 2021-05-19 22:04:29 +03:00
parent 75cc1cf018
commit 9afbca746b
2 changed files with 104 additions and 92 deletions

View File

@ -606,6 +606,59 @@ private:
using ConvertToLLVMPattern::matchAndRewrite;
};
/// Lowering for AllocOp and AllocaOp.
struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern {
using ConvertToLLVMPattern::createIndexConstant;
using ConvertToLLVMPattern::getIndexType;
using ConvertToLLVMPattern::getVoidPtrType;
explicit AllocLikeOpLLVMLowering(StringRef opName,
LLVMTypeConverter &converter)
: ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
protected:
// Returns 'input' aligned up to 'alignment'. Computes
// bumped = input + alignement - 1
// aligned = bumped - bumped % alignment
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
Value input, Value alignment);
/// Allocates the underlying buffer. Returns the allocated pointer and the
/// aligned pointer.
virtual std::tuple<Value, Value>
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
Value sizeBytes, Operation *op) const = 0;
private:
static MemRefType getMemRefResultType(Operation *op) {
return op->getResult(0).getType().cast<MemRefType>();
}
LogicalResult match(Operation *op) const override {
MemRefType memRefType = getMemRefResultType(op);
return success(isConvertibleAndHasIdentityMaps(memRefType));
}
// An `alloc` is converted into a definition of a memref descriptor value and
// a call to `malloc` to allocate the underlying data buffer. The memref
// descriptor is of the LLVM structure type where:
// 1. the first element is a pointer to the allocated (typed) data buffer,
// 2. the second element is a pointer to the (typed) payload, aligned to the
// specified alignment,
// 3. the remaining elements serve to store all the sizes and strides of the
// memref using LLVM-converted `index` type.
//
// Alignment is performed by allocating `alignment` more bytes than
// requested and shifting the aligned pointer relative to the allocated
// memory. Note: `alignment - <minimum malloc alignment>` would actually be
// sufficient. If alignment is unspecified, the two pointers are equal.
// An `alloca` is converted into a definition of a memref descriptor value and
// an llvm.alloca to allocate the underlying data buffer.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
namespace LLVM {
namespace detail {
/// Replaces the given operation "op" with a new operation of type "targetOp"

View File

@ -1831,92 +1831,10 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
}
};
/// Lowering for AllocOp and AllocaOp.
struct AllocLikeOpLowering : public ConvertToLLVMPattern {
using ConvertToLLVMPattern::createIndexConstant;
using ConvertToLLVMPattern::getIndexType;
using ConvertToLLVMPattern::getVoidPtrType;
explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter)
: ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
protected:
// Returns 'input' aligned up to 'alignment'. Computes
// bumped = input + alignement - 1
// aligned = bumped - bumped % alignment
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
Value input, Value alignment) {
Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
}
/// Allocates the underlying buffer. Returns the allocated pointer and the
/// aligned pointer.
virtual std::tuple<Value, Value>
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
Value sizeBytes, Operation *op) const = 0;
private:
static MemRefType getMemRefResultType(Operation *op) {
return op->getResult(0).getType().cast<MemRefType>();
}
LogicalResult match(Operation *op) const override {
MemRefType memRefType = getMemRefResultType(op);
return success(isConvertibleAndHasIdentityMaps(memRefType));
}
// An `alloc` is converted into a definition of a memref descriptor value and
// a call to `malloc` to allocate the underlying data buffer. The memref
// descriptor is of the LLVM structure type where:
// 1. the first element is a pointer to the allocated (typed) data buffer,
// 2. the second element is a pointer to the (typed) payload, aligned to the
// specified alignment,
// 3. the remaining elements serve to store all the sizes and strides of the
// memref using LLVM-converted `index` type.
//
// Alignment is performed by allocating `alignment` more bytes than
// requested and shifting the aligned pointer relative to the allocated
// memory. Note: `alignment - <minimum malloc alignment>` would actually be
// sufficient. If alignment is unspecified, the two pointers are equal.
// An `alloca` is converted into a definition of a memref descriptor value and
// an llvm.alloca to allocate the underlying data buffer.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
MemRefType memRefType = getMemRefResultType(op);
auto loc = op->getLoc();
// Get actual sizes of the memref as values: static sizes are constant
// values and dynamic sizes are passed to 'alloc' as operands. In case of
// zero-dimensional memref, assume a scalar (size 1).
SmallVector<Value, 4> sizes;
SmallVector<Value, 4> strides;
Value sizeBytes;
this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
strides, sizeBytes);
// Allocate the underlying buffer.
Value allocatedPtr;
Value alignedPtr;
std::tie(allocatedPtr, alignedPtr) =
this->allocateBuffer(rewriter, loc, sizeBytes, op);
// Create the MemRef descriptor.
auto memRefDescriptor = this->createMemRefDescriptor(
loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
// Return the final value of the descriptor.
rewriter.replaceOp(op, {memRefDescriptor});
}
};
struct AllocOpLowering : public AllocLikeOpLowering {
struct AllocOpLowering : public AllocLikeOpLLVMLowering {
AllocOpLowering(LLVMTypeConverter &converter)
: AllocLikeOpLowering(memref::AllocOp::getOperationName(), converter) {}
: AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
converter) {}
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
Location loc, Value sizeBytes,
@ -1967,9 +1885,10 @@ struct AllocOpLowering : public AllocLikeOpLowering {
}
};
struct AlignedAllocOpLowering : public AllocLikeOpLowering {
struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
AlignedAllocOpLowering(LLVMTypeConverter &converter)
: AllocLikeOpLowering(memref::AllocOp::getOperationName(), converter) {}
: AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
converter) {}
/// Returns the memref's element size in bytes.
// TODO: there are other places where this is used. Expose publicly?
@ -2047,9 +1966,10 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
// Out of line definition, required till C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
struct AllocaOpLowering : public AllocLikeOpLowering {
struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
AllocaOpLowering(LLVMTypeConverter &converter)
: AllocLikeOpLowering(memref::AllocaOp::getOperationName(), converter) {}
: AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
converter) {}
/// Allocates the underlying buffer using the right call. `allocatedBytePtr`
/// is set to null for stack allocations. `accessAlignment` is set if
@ -2310,10 +2230,10 @@ struct GlobalMemrefOpLowering
/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
/// the first element stashed into the descriptor. This reuses
/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
: AllocLikeOpLowering(memref::GetGlobalOp::getOperationName(),
converter) {}
: AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
converter) {}
/// Buffer "allocation" for memref.get_global op is getting the address of
/// the global variable referenced.
@ -4195,6 +4115,45 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
};
} // end namespace
Value AllocLikeOpLLVMLowering::createAligned(
ConversionPatternRewriter &rewriter, Location loc, Value input,
Value alignment) {
Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
}
void AllocLikeOpLLVMLowering::rewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
MemRefType memRefType = getMemRefResultType(op);
auto loc = op->getLoc();
// Get actual sizes of the memref as values: static sizes are constant
// values and dynamic sizes are passed to 'alloc' as operands. In case of
// zero-dimensional memref, assume a scalar (size 1).
SmallVector<Value, 4> sizes;
SmallVector<Value, 4> strides;
Value sizeBytes;
this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
strides, sizeBytes);
// Allocate the underlying buffer.
Value allocatedPtr;
Value alignedPtr;
std::tie(allocatedPtr, alignedPtr) =
this->allocateBuffer(rewriter, loc, sizeBytes, op);
// Create the MemRef descriptor.
auto memRefDescriptor = this->createMemRefDescriptor(
loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
// Return the final value of the descriptor.
rewriter.replaceOp(op, {memRefDescriptor});
}
mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
this->addLegalDialect<LLVM::LLVMDialect>();