forked from OSchip/llvm-project
[mlir][memref] Implement fast lowering of memref.copy
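In the absence of layout maps (i.e. when both the source and target memrefs are ranked and have identity layouts), memref.copy can be lowered directly to the LLVM memcpy intrinsic. Differential Revision: https://reviews.llvm.org/D116099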
This commit is contained in:
parent
713c2b47a0
commit
ab95ba704d
|
@ -706,12 +706,52 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Pattern to lower a `memref.copy` op to the LLVM dialect.
|
||||||
|
///
|
||||||
|
/// For memrefs with identity layouts, the copy is lowered to the llvm
|
||||||
|
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
|
||||||
|
/// to the generic `MemrefCopyFn`. Unranked memrefs also take this generic path.
|
||||||
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
|
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
|
||||||
using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
|
using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
|
lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto srcType = op.source().getType().dyn_cast<MemRefType>();
|
||||||
|
|
||||||
|
MemRefDescriptor srcDesc(adaptor.source());
|
||||||
|
|
||||||
|
// Compute number of elements.
|
||||||
|
Value numElements;
|
||||||
|
for (int pos = 0; pos < srcType.getRank(); ++pos) {
|
||||||
|
auto size = srcDesc.size(rewriter, loc, pos);
|
||||||
|
numElements = numElements
|
||||||
|
? rewriter.create<LLVM::MulOp>(loc, numElements, size)
|
||||||
|
: size;
|
||||||
|
}
|
||||||
|
// Get element size.
|
||||||
|
auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
|
||||||
|
// Compute total.
|
||||||
|
Value totalSize =
|
||||||
|
rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
|
||||||
|
|
||||||
|
Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
|
||||||
|
MemRefDescriptor targetDesc(adaptor.target());
|
||||||
|
Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
|
||||||
|
Value isVolatile = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, typeConverter->convertType(rewriter.getI1Type()),
|
||||||
|
rewriter.getBoolAttr(false));
|
||||||
|
rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
|
||||||
|
isVolatile);
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
auto srcType = op.source().getType().cast<BaseMemRefType>();
|
auto srcType = op.source().getType().cast<BaseMemRefType>();
|
||||||
auto targetType = op.target().getType().cast<BaseMemRefType>();
|
auto targetType = op.target().getType().cast<BaseMemRefType>();
|
||||||
|
@ -765,6 +805,21 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto srcType = op.source().getType().cast<BaseMemRefType>();
|
||||||
|
auto targetType = op.target().getType().cast<BaseMemRefType>();
|
||||||
|
|
||||||
|
if (srcType.hasRank() &&
|
||||||
|
srcType.cast<MemRefType>().getLayout().isIdentity() &&
|
||||||
|
targetType.hasRank() &&
|
||||||
|
targetType.cast<MemRefType>().getLayout().isIdentity())
|
||||||
|
return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
|
||||||
|
|
||||||
|
return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Extracts allocated, aligned pointers and offset from a ranked or unranked
|
/// Extracts allocated, aligned pointers and offset from a ranked or unranked
|
||||||
|
|
|
@ -35,7 +35,7 @@ func @main() -> () {
|
||||||
// CHECK-NEXT: [3, 4, 5]
|
// CHECK-NEXT: [3, 4, 5]
|
||||||
|
|
||||||
%copy_two = memref.alloc() : memref<3x2xf32>
|
%copy_two = memref.alloc() : memref<3x2xf32>
|
||||||
%copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides:[1, 2]
|
%copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides: [1, 2]
|
||||||
: memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
|
: memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
|
||||||
memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
|
memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
|
||||||
%unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
|
%unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
|
||||||
|
@ -49,6 +49,13 @@ func @main() -> () {
|
||||||
%copy_empty = memref.alloc() : memref<3x0x1xf32>
|
%copy_empty = memref.alloc() : memref<3x0x1xf32>
|
||||||
// Copying an empty shape should do nothing (and should not crash).
|
// Copying an empty shape should do nothing (and should not crash).
|
||||||
memref.copy %input_empty, %copy_empty : memref<3x0x1xf32> to memref<3x0x1xf32>
|
memref.copy %input_empty, %copy_empty : memref<3x0x1xf32> to memref<3x0x1xf32>
|
||||||
|
|
||||||
|
%input_empty_casted = memref.reinterpret_cast %input_empty to offset: [0], sizes: [0, 3, 1], strides: [3, 1, 1]
|
||||||
|
: memref<3x0x1xf32> to memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]>
|
||||||
|
%copy_empty_casted = memref.alloc() : memref<0x3x1xf32>
|
||||||
|
// Copying a casted empty shape should do nothing (and should not crash).
|
||||||
|
memref.copy %input_empty_casted, %copy_empty_casted : memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]> to memref<0x3x1xf32>
|
||||||
|
|
||||||
memref.dealloc %copy_empty : memref<3x0x1xf32>
|
memref.dealloc %copy_empty : memref<3x0x1xf32>
|
||||||
memref.dealloc %input_empty : memref<3x0x1xf32>
|
memref.dealloc %input_empty : memref<3x0x1xf32>
|
||||||
memref.dealloc %copy_two : memref<3x2xf32>
|
memref.dealloc %copy_two : memref<3x2xf32>
|
||||||
|
|
Loading…
Reference in New Issue