forked from OSchip/llvm-project
[mlir][memref] Implement lowering of memref.copy to llvm
This lowering uses a library call to implement copying in the general case, i.e., supporting arbitrary rank and strided layouts.
This commit is contained in:
parent
4a6bd8e3e7
commit
e939644977
|
@ -45,6 +45,8 @@ LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
|
||||||
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
|
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
|
||||||
Type indexType);
|
Type indexType);
|
||||||
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
|
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
|
||||||
|
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
|
||||||
|
Type unrankedDescriptorType);
|
||||||
|
|
||||||
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
|
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
|
||||||
LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
|
LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
|
||||||
|
|
|
@ -330,6 +330,13 @@ public:
|
||||||
const int64_t *strides;
|
const int64_t *strides;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Small runtime support library for memref.copy lowering during codegen.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
extern "C" MLIR_CRUNNERUTILS_EXPORT void
|
||||||
|
memrefCopy(int64_t elemSize, UnrankedMemRefType<char> *src,
|
||||||
|
UnrankedMemRefType<char> *dst);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Small runtime support library for vector.print lowering during codegen.
|
// Small runtime support library for vector.print lowering during codegen.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -2618,6 +2618,68 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
|
||||||
|
using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(memref::CopyOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
memref::CopyOp::Adaptor adaptor(operands);
|
||||||
|
auto srcType = op.source().getType().cast<BaseMemRefType>();
|
||||||
|
auto targetType = op.target().getType().cast<BaseMemRefType>();
|
||||||
|
|
||||||
|
// First make sure we have an unranked memref descriptor representation.
|
||||||
|
auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
|
||||||
|
auto rank = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
|
||||||
|
auto *typeConverter = getTypeConverter();
|
||||||
|
auto ptr =
|
||||||
|
typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
|
||||||
|
auto voidPtr =
|
||||||
|
rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
|
||||||
|
.getResult();
|
||||||
|
auto unrankedType =
|
||||||
|
UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
|
||||||
|
return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
|
||||||
|
unrankedType,
|
||||||
|
ValueRange{rank, voidPtr});
|
||||||
|
};
|
||||||
|
|
||||||
|
Value unrankedSource = srcType.hasRank()
|
||||||
|
? makeUnranked(adaptor.source(), srcType)
|
||||||
|
: adaptor.source();
|
||||||
|
Value unrankedTarget = targetType.hasRank()
|
||||||
|
? makeUnranked(adaptor.target(), targetType)
|
||||||
|
: adaptor.target();
|
||||||
|
|
||||||
|
// Now promote the unranked descriptors to the stack.
|
||||||
|
auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
|
||||||
|
rewriter.getIndexAttr(1));
|
||||||
|
auto promote = [&](Value desc) {
|
||||||
|
auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
|
||||||
|
auto allocated =
|
||||||
|
rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
|
||||||
|
rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
|
||||||
|
return allocated;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto sourcePtr = promote(unrankedSource);
|
||||||
|
auto targetPtr = promote(unrankedTarget);
|
||||||
|
|
||||||
|
auto elemSize = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, getIndexType(),
|
||||||
|
rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
|
||||||
|
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
|
||||||
|
op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
|
||||||
|
rewriter.create<LLVM::CallOp>(loc, copyFn,
|
||||||
|
ValueRange{elemSize, sourcePtr, targetPtr});
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Extracts allocated, aligned pointers and offset from a ranked or unranked
|
/// Extracts allocated, aligned pointers and offset from a ranked or unranked
|
||||||
/// memref type. In unranked case, the fields are extracted from the underlying
|
/// memref type. In unranked case, the fields are extracted from the underlying
|
||||||
/// ranked descriptor.
|
/// ranked descriptor.
|
||||||
|
@ -4009,6 +4071,7 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
|
||||||
GetGlobalMemrefOpLowering,
|
GetGlobalMemrefOpLowering,
|
||||||
LoadOpLowering,
|
LoadOpLowering,
|
||||||
MemRefCastOpLowering,
|
MemRefCastOpLowering,
|
||||||
|
MemRefCopyOpLowering,
|
||||||
MemRefReinterpretCastOpLowering,
|
MemRefReinterpretCastOpLowering,
|
||||||
MemRefReshapeOpLowering,
|
MemRefReshapeOpLowering,
|
||||||
RankOpLowering,
|
RankOpLowering,
|
||||||
|
|
|
@ -35,6 +35,7 @@ static constexpr llvm::StringRef kPrintNewline = "printNewline";
|
||||||
static constexpr llvm::StringRef kMalloc = "malloc";
|
static constexpr llvm::StringRef kMalloc = "malloc";
|
||||||
static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
|
static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
|
||||||
static constexpr llvm::StringRef kFree = "free";
|
static constexpr llvm::StringRef kFree = "free";
|
||||||
|
static constexpr llvm::StringRef kMemRefCopy = "memref_copy";
|
||||||
|
|
||||||
/// Generic print function lookupOrCreate helper.
|
/// Generic print function lookupOrCreate helper.
|
||||||
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
|
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
|
||||||
|
@ -114,6 +115,15 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
|
||||||
LLVM::LLVMVoidType::get(moduleOp->getContext()));
|
LLVM::LLVMVoidType::get(moduleOp->getContext()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LLVM::LLVMFuncOp
|
||||||
|
mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
|
||||||
|
Type unrankedDescriptorType) {
|
||||||
|
return LLVM::lookupOrCreateFn(
|
||||||
|
moduleOp, kMemRefCopy,
|
||||||
|
ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
|
||||||
|
LLVM::LLVMVoidType::get(moduleOp->getContext()));
|
||||||
|
}
|
||||||
|
|
||||||
Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
|
Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
|
||||||
LLVM::LLVMFuncOp fn,
|
LLVM::LLVMFuncOp fn,
|
||||||
ValueRange paramTypes,
|
ValueRange paramTypes,
|
||||||
|
|
|
@ -18,8 +18,10 @@
|
||||||
#include <sys/time.h>
|
#include <sys/time.h>
|
||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
|
|
||||||
|
#include <alloca.h>
|
||||||
#include <cinttypes>
|
#include <cinttypes>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
|
#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
|
||||||
|
|
||||||
|
@ -36,6 +38,52 @@ extern "C" void printClose() { fputs(" )", stdout); }
|
||||||
extern "C" void printComma() { fputs(", ", stdout); }
|
extern "C" void printComma() { fputs(", ", stdout); }
|
||||||
extern "C" void printNewline() { fputc('\n', stdout); }
|
extern "C" void printNewline() { fputc('\n', stdout); }
|
||||||
|
|
||||||
|
extern "C" MLIR_CRUNNERUTILS_EXPORT void
|
||||||
|
memrefCopy(int64_t elemSize, UnrankedMemRefType<char> *srcArg,
|
||||||
|
UnrankedMemRefType<char> *dstArg) {
|
||||||
|
DynamicMemRefType<char> src(*srcArg);
|
||||||
|
DynamicMemRefType<char> dst(*dstArg);
|
||||||
|
|
||||||
|
int64_t rank = src.rank;
|
||||||
|
int64_t *indices = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
|
||||||
|
int64_t *srcStrides = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
|
||||||
|
int64_t *dstStrides = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
|
||||||
|
|
||||||
|
char *srcPtr = src.data + src.offset * elemSize;
|
||||||
|
char *dstPtr = dst.data + dst.offset * elemSize;
|
||||||
|
|
||||||
|
// Initialize index and scale strides.
|
||||||
|
for (int rankp = 0; rankp < rank; ++rankp) {
|
||||||
|
indices[rankp] = 0;
|
||||||
|
srcStrides[rankp] = src.strides[rankp] * elemSize;
|
||||||
|
dstStrides[rankp] = dst.strides[rankp] * elemSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t readIndex = 0, writeIndex = 0;
|
||||||
|
for (;;) {
|
||||||
|
// Copy over the element, byte by byte.
|
||||||
|
memcpy(dstPtr + writeIndex, srcPtr + readIndex, elemSize);
|
||||||
|
// Advance index and read position.
|
||||||
|
for (int64_t axis = rank - 1; axis >= 0; --axis) {
|
||||||
|
// Advance at current axis.
|
||||||
|
auto newIndex = ++indices[axis];
|
||||||
|
readIndex += srcStrides[axis];
|
||||||
|
writeIndex += dstStrides[axis];
|
||||||
|
// If this is a valid index, we have our next index, so continue copying.
|
||||||
|
if (src.sizes[axis] != newIndex)
|
||||||
|
break;
|
||||||
|
// We reached the end of this axis. If this is axis 0, we are done.
|
||||||
|
if (axis == 0)
|
||||||
|
return;
|
||||||
|
// Else, reset to 0 and undo the advancement of the linear index that
|
||||||
|
// this axis had. The continue with the axis one outer.
|
||||||
|
indices[axis] = 0;
|
||||||
|
readIndex -= src.sizes[axis] * srcStrides[axis];
|
||||||
|
writeIndex -= dst.sizes[axis] * dstStrides[axis];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Prints GFLOPS rating.
|
/// Prints GFLOPS rating.
|
||||||
extern "C" void print_flops(double flops) {
|
extern "C" void print_flops(double flops) {
|
||||||
fprintf(stderr, "%lf GFLOPS\n", flops / 1.0E9);
|
fprintf(stderr, "%lf GFLOPS\n", flops / 1.0E9);
|
||||||
|
|
Loading…
Reference in New Issue