diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 7efff9774cd5..6380ff2d8e13 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -45,6 +45,8 @@ LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType); LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp, Type indexType); LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType, + Type unrankedDescriptorType); /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`. LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name, diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h index fb0b2a65a67e..bd855fcc03a9 100644 --- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h +++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h @@ -330,6 +330,13 @@ public: const int64_t *strides; }; +//===----------------------------------------------------------------------===// +// Small runtime support library for memref.copy lowering during codegen. +//===----------------------------------------------------------------------===// +extern "C" MLIR_CRUNNERUTILS_EXPORT void +memrefCopy(int64_t elemSize, UnrankedMemRefType *src, + UnrankedMemRefType *dst); + //===----------------------------------------------------------------------===// // Small runtime support library for vector.print lowering during codegen. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index db5918e95f18..eb390bf8844f 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2618,6 +2618,68 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { } }; +struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::CopyOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + memref::CopyOp::Adaptor adaptor(operands); + auto srcType = op.source().getType().cast(); + auto targetType = op.target().getType().cast(); + + // First make sure we have an unranked memref descriptor representation. + auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) { + auto rank = rewriter.create( + loc, getIndexType(), rewriter.getIndexAttr(type.getRank())); + auto *typeConverter = getTypeConverter(); + auto ptr = + typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter); + auto voidPtr = + rewriter.create(loc, getVoidPtrType(), ptr) + .getResult(); + auto unrankedType = + UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace()); + return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter, + unrankedType, + ValueRange{rank, voidPtr}); + }; + + Value unrankedSource = srcType.hasRank() + ? makeUnranked(adaptor.source(), srcType) + : adaptor.source(); + Value unrankedTarget = targetType.hasRank() + ? makeUnranked(adaptor.target(), targetType) + : adaptor.target(); + + // Now promote the unranked descriptors to the stack. + auto one = rewriter.create(loc, getIndexType(), + rewriter.getIndexAttr(1)); + auto promote = [&](Value desc) { + auto ptrType = LLVM::LLVMPointerType::get(desc.getType()); + auto allocated = + rewriter.create(loc, ptrType, ValueRange{one}); + rewriter.create(loc, desc, allocated); + return allocated; + }; + + auto sourcePtr = promote(unrankedSource); + auto targetPtr = promote(unrankedTarget); + + auto elemSize = rewriter.create( + loc, getIndexType(), + rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8)); + auto copyFn = LLVM::lookupOrCreateMemRefCopyFn( + op->getParentOfType(), getIndexType(), sourcePtr.getType()); + rewriter.create(loc, copyFn, + ValueRange{elemSize, sourcePtr, targetPtr}); + rewriter.eraseOp(op); + + return success(); + } +}; + /// Extracts allocated, aligned pointers and offset from a ranked or unranked /// memref type. In unranked case, the fields are extracted from the underlying /// ranked descriptor. @@ -4009,6 +4071,7 @@ void mlir::populateStdToLLVMMemoryConversionPatterns( GetGlobalMemrefOpLowering, LoadOpLowering, MemRefCastOpLowering, + MemRefCopyOpLowering, MemRefReinterpretCastOpLowering, MemRefReshapeOpLowering, RankOpLowering, diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index a43c2251c2d9..47a5851b51f2 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -35,6 +35,7 @@ static constexpr llvm::StringRef kPrintNewline = "printNewline"; static constexpr llvm::StringRef kMalloc = "malloc"; static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc"; static constexpr llvm::StringRef kFree = "free"; +static constexpr llvm::StringRef kMemRefCopy = "memref_copy"; /// Generic print function lookupOrCreate helper. LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name, @@ -114,6 +115,15 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) { LLVM::LLVMVoidType::get(moduleOp->getContext())); } +LLVM::LLVMFuncOp +mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType, + Type unrankedDescriptorType) { + return LLVM::lookupOrCreateFn( + moduleOp, kMemRefCopy, + ArrayRef{indexType, unrankedDescriptorType, unrankedDescriptorType}, + LLVM::LLVMVoidType::get(moduleOp->getContext())); +} + Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc, LLVM::LLVMFuncOp fn, ValueRange paramTypes, diff --git a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp index e5b682a7b6de..bf96afb73725 100644 --- a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp +++ b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp @@ -18,8 +18,10 @@ #include #endif // _WIN32 +#include #include #include +#include #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS @@ -36,6 +38,52 @@ extern "C" void printClose() { fputs(" )", stdout); } extern "C" void printComma() { fputs(", ", stdout); } extern "C" void printNewline() { fputc('\n', stdout); } +extern "C" MLIR_CRUNNERUTILS_EXPORT void +memrefCopy(int64_t elemSize, UnrankedMemRefType *srcArg, + UnrankedMemRefType *dstArg) { + DynamicMemRefType src(*srcArg); + DynamicMemRefType dst(*dstArg); + + int64_t rank = src.rank; + int64_t *indices = static_cast(alloca(sizeof(int64_t) * rank)); + int64_t *srcStrides = static_cast(alloca(sizeof(int64_t) * rank)); + int64_t *dstStrides = static_cast(alloca(sizeof(int64_t) * rank)); + + char *srcPtr = src.data + src.offset * elemSize; + char *dstPtr = dst.data + dst.offset * elemSize; + + // Initialize index and scale strides. + for (int rankp = 0; rankp < rank; ++rankp) { + indices[rankp] = 0; + srcStrides[rankp] = src.strides[rankp] * elemSize; + dstStrides[rankp] = dst.strides[rankp] * elemSize; + } + + int64_t readIndex = 0, writeIndex = 0; + for (;;) { + // Copy over the element, byte by byte. + memcpy(dstPtr + writeIndex, srcPtr + readIndex, elemSize); + // Advance index and read position. + for (int64_t axis = rank - 1; axis >= 0; --axis) { + // Advance at current axis. + auto newIndex = ++indices[axis]; + readIndex += srcStrides[axis]; + writeIndex += dstStrides[axis]; + // If this is a valid index, we have our next index, so continue copying. + if (src.sizes[axis] != newIndex) + break; + // We reached the end of this axis. If this is axis 0, we are done. + if (axis == 0) + return; + // Else, reset to 0 and undo the advancement of the linear index that + // this axis had. The continue with the axis one outer. + indices[axis] = 0; + readIndex -= src.sizes[axis] * srcStrides[axis]; + writeIndex -= dst.sizes[axis] * dstStrides[axis]; + } + } +} + /// Prints GFLOPS rating. extern "C" void print_flops(double flops) { fprintf(stderr, "%lf GFLOPS\n", flops / 1.0E9);