forked from OSchip/llvm-project
[mlir][memref] Implement fast lowering of memref.copy
In the absence of layout maps (i.e., for identity-layout memrefs), we can lower memref.copy directly to the llvm.memcpy intrinsic instead of a generic copy-function call. Differential Revision: https://reviews.llvm.org/D116099
This commit is contained in:
parent
713c2b47a0
commit
ab95ba704d
|
@ -706,12 +706,52 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Pattern to lower a `memref.copy` to llvm.
|
||||
///
|
||||
/// For memrefs with identity layouts, the copy is lowered to the llvm
|
||||
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
|
||||
/// to the generic `MemrefCopyFn`.
|
||||
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
|
||||
using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
/// Lowers the copy to the LLVM `memcpy` intrinsic.
///
/// Computes the byte count as (product of the source's sizes) * element size,
/// reading the sizes from the source memref descriptor, then emits a
/// non-volatile `llvm.memcpy` between the *aligned* pointers of the two
/// descriptors and erases the original op.
///
/// NOTE(review): this reads only the aligned pointers and ignores the
/// descriptor offsets/strides, so it is only correct for identity layouts —
/// the caller (matchAndRewrite) must guarantee that precondition.
///
/// Bug fix vs. the original: `numElements` was default-initialized to a null
/// Value; for a rank-0 (scalar) memref the loop never runs and the later
/// `MulOp` would be built from a null operand. Start the product at a
/// constant 1 instead, so rank-0 copies lower to a memcpy of one element.
LogicalResult
lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  auto srcType = op.source().getType().dyn_cast<MemRefType>();

  MemRefDescriptor srcDesc(adaptor.source());

  // Compute number of elements. Seed the product with 1 so that a rank-0
  // memref yields one element rather than a null Value.
  Value numElements = rewriter.create<LLVM::ConstantOp>(
      loc, getIndexType(), rewriter.getIndexAttr(1));
  for (int pos = 0; pos < srcType.getRank(); ++pos) {
    auto size = srcDesc.size(rewriter, loc, pos);
    numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
  }

  // Get element size.
  auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
  // Compute total byte count.
  Value totalSize =
      rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

  Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
  MemRefDescriptor targetDesc(adaptor.target());
  Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
  // llvm.memcpy requires an explicit is-volatile flag; this copy is not.
  Value isVolatile = rewriter.create<LLVM::ConstantOp>(
      loc, typeConverter->convertType(rewriter.getI1Type()),
      rewriter.getBoolAttr(false));
  rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
                                  isVolatile);
  rewriter.eraseOp(op);

  return success();
}
|
||||
|
||||
LogicalResult
|
||||
lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
auto srcType = op.source().getType().cast<BaseMemRefType>();
|
||||
auto targetType = op.target().getType().cast<BaseMemRefType>();
|
||||
|
@ -765,6 +805,21 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
|
|||
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Dispatches the lowering of `memref.copy`: ranked memrefs whose layouts
/// are both the identity take the fast `memcpy`-intrinsic path; everything
/// else (unranked operands or non-trivial layouts) falls back to a call to
/// the generic copy function.
LogicalResult
matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto srcType = op.source().getType().cast<BaseMemRefType>();
  auto targetType = op.target().getType().cast<BaseMemRefType>();

  // A side is "trivial" when it is ranked and its layout is the identity,
  // i.e. the data is contiguous with offset 0.
  auto isContiguousIdentity = [](BaseMemRefType type) {
    return type.hasRank() &&
           type.cast<MemRefType>().getLayout().isIdentity();
  };

  if (isContiguousIdentity(srcType) && isContiguousIdentity(targetType))
    return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

  return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
}
|
||||
};
|
||||
|
||||
/// Extracts allocated, aligned pointers and offset from a ranked or unranked
|
||||
|
|
|
@ -35,7 +35,7 @@ func @main() -> () {
|
|||
// CHECK-NEXT: [3, 4, 5]
|
||||
|
||||
%copy_two = memref.alloc() : memref<3x2xf32>
|
||||
%copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides:[1, 2]
|
||||
%copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides: [1, 2]
|
||||
: memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
|
||||
memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
|
||||
%unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
|
||||
|
@ -49,6 +49,13 @@ func @main() -> () {
|
|||
%copy_empty = memref.alloc() : memref<3x0x1xf32>
|
||||
// Copying an empty shape should do nothing (and should not crash).
|
||||
memref.copy %input_empty, %copy_empty : memref<3x0x1xf32> to memref<3x0x1xf32>
|
||||
|
||||
%input_empty_casted = memref.reinterpret_cast %input_empty to offset: [0], sizes: [0, 3, 1], strides: [3, 1, 1]
|
||||
: memref<3x0x1xf32> to memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]>
|
||||
%copy_empty_casted = memref.alloc() : memref<0x3x1xf32>
|
||||
// Copying a casted empty shape should do nothing (and should not crash).
|
||||
memref.copy %input_empty_casted, %copy_empty_casted : memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]> to memref<0x3x1xf32>
|
||||
|
||||
memref.dealloc %copy_empty : memref<3x0x1xf32>
|
||||
memref.dealloc %input_empty : memref<3x0x1xf32>
|
||||
memref.dealloc %copy_two : memref<3x2xf32>
|
||||
|
|
Loading…
Reference in New Issue