[mlir][memref] Implement fast lowering of memref.copy

In the absence of layout maps (i.e. when both the source and the target memref have an identity layout), memref.copy can be lowered directly to a memcpy.

Differential Revision: https://reviews.llvm.org/D116099
This commit is contained in:
Stephan Herhut 2022-01-07 10:00:19 +01:00
parent 713c2b47a0
commit ab95ba704d
2 changed files with 65 additions and 3 deletions

View File

@ -706,12 +706,52 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
}
};
/// Pattern to lower a `memref.copy` to llvm.
///
/// When both the source and the target are ranked memrefs with identity
/// layouts, the copy is lowered to the llvm `memcpy` intrinsic. In all other
/// cases (unranked memrefs or non-identity layouts), the copy is lowered to a
/// call to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const {
  // Lowers `op` to a single `llvm.memcpy` from the source's aligned pointer to
  // the target's aligned pointer. The byte count is the product of all source
  // dimension sizes times the element size. Always erases `op` and returns
  // success.
  auto loc = op.getLoc();
  // `matchAndRewrite` only dispatches here for ranked memrefs, so use `cast`
  // (which asserts) instead of `dyn_cast` (which would silently yield null).
  auto srcType = op.source().getType().cast<MemRefType>();
  MemRefDescriptor srcDesc(adaptor.source());

  // Compute number of elements as the product of all dimension sizes. For a
  // rank-0 memref the loop body never runs and `numElements` stays null.
  Value numElements;
  for (int pos = 0; pos < srcType.getRank(); ++pos) {
    auto size = srcDesc.size(rewriter, loc, pos);
    numElements = numElements
                      ? rewriter.create<LLVM::MulOp>(loc, numElements, size)
                      : size;
  }

  // Get element size.
  auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);

  // Compute total byte count. A rank-0 memref holds exactly one element, so
  // the total is just the element size; multiplying by the null
  // `numElements` would create an `llvm.mul` with an invalid operand.
  Value totalSize = sizeInBytes;
  if (numElements)
    totalSize = rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

  Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
  MemRefDescriptor targetDesc(adaptor.target());
  Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);

  // The copy is emitted as a non-volatile memcpy.
  Value isVolatile = rewriter.create<LLVM::ConstantOp>(
      loc, typeConverter->convertType(rewriter.getI1Type()),
      rewriter.getBoolAttr(false));
  rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
                                  isVolatile);
  rewriter.eraseOp(op);

  return success();
}
LogicalResult
lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto srcType = op.source().getType().cast<BaseMemRefType>();
auto targetType = op.target().getType().cast<BaseMemRefType>();
@ -765,6 +805,21 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
return success();
}
/// Dispatches the lowering of `op`: the `memcpy` intrinsic path is taken only
/// when both the source and the target are ranked memrefs with an identity
/// layout; every other combination goes through the generic copy function.
LogicalResult
matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // True when `type` is ranked and its layout is the identity.
  auto isRankedIdentity = [](BaseMemRefType type) {
    if (!type.hasRank())
      return false;
    return type.cast<MemRefType>().getLayout().isIdentity();
  };

  auto sourceTy = op.source().getType().cast<BaseMemRefType>();
  auto targetTy = op.target().getType().cast<BaseMemRefType>();

  if (!isRankedIdentity(sourceTy) || !isRankedIdentity(targetTy))
    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);

  return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
}
};
/// Extracts allocated, aligned pointers and offset from a ranked or unranked

View File

@ -35,7 +35,7 @@ func @main() -> () {
// CHECK-NEXT: [3, 4, 5]
%copy_two = memref.alloc() : memref<3x2xf32>
%copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides:[1, 2]
%copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides: [1, 2]
: memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
%unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
@ -49,6 +49,13 @@ func @main() -> () {
%copy_empty = memref.alloc() : memref<3x0x1xf32>
// Copying an empty shape should do nothing (and should not crash).
memref.copy %input_empty, %copy_empty : memref<3x0x1xf32> to memref<3x0x1xf32>
%input_empty_casted = memref.reinterpret_cast %input_empty to offset: [0], sizes: [0, 3, 1], strides: [3, 1, 1]
: memref<3x0x1xf32> to memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]>
%copy_empty_casted = memref.alloc() : memref<0x3x1xf32>
// Copying a casted empty shape should do nothing (and should not crash).
memref.copy %input_empty_casted, %copy_empty_casted : memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]> to memref<0x3x1xf32>
memref.dealloc %copy_empty : memref<3x0x1xf32>
memref.dealloc %input_empty : memref<3x0x1xf32>
memref.dealloc %copy_two : memref<3x2xf32>