From fab4b59961aa35109861493dfe071979d56b4360 Mon Sep 17 00:00:00 2001 From: "Arpith C. Jacob" Date: Wed, 5 Aug 2020 12:16:17 +0200 Subject: [PATCH] [mlir] Conversion of ViewOp with memory space to LLVM. Handle the case where the ViewOp takes in a memref that has an memory space. Reviewed By: ftynse, bondhugula, nicolasvasilache Differential Revision: https://reviews.llvm.org/D85048 --- .../StandardToLLVM/StandardToLLVM.cpp | 7 ++++-- .../StandardToLLVM/convert-to-llvmir.mlir | 22 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 533ac629ba5a..2ada7c425600 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2960,8 +2960,10 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { // Field 1: Copy the allocated pointer, used for malloc/free. Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); + auto srcMemRefType = viewOp.source().getType().cast(); Value bitcastPtr = rewriter.create( - loc, targetElementTy.getPointerTo(), allocatedPtr); + loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()), + allocatedPtr); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); // Field 2: Copy the actual aligned pointer to payload. @@ -2969,7 +2971,8 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { alignedPtr = rewriter.create(loc, alignedPtr.getType(), alignedPtr, adaptor.byte_shift()); bitcastPtr = rewriter.create( - loc, targetElementTy.getPointerTo(), alignedPtr); + loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()), + alignedPtr); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); // Field 3: The offset in the resulting type must be 0. This is because of diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir index 6123f68b7e85..9042bf36c1b3 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -824,6 +824,28 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) { // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %5 = view %0[%arg2][] : memref<2048xi8> to memref<64x4xf32> + // Test view memory space. + // CHECK: llvm.mlir.constant(2048 : index) : !llvm.i64 + // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %6 = alloc() : memref<2048xi8, 4> + + // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[BASE_PTR_4:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[SHIFTED_BASE_PTR_4:.*]] = llvm.getelementptr %[[BASE_PTR_4]][%[[ARG2]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + // CHECK: %[[CAST_SHIFTED_BASE_PTR_4:.*]] = llvm.bitcast %[[SHIFTED_BASE_PTR_4]] : !llvm.ptr to !llvm.ptr + // CHECK: llvm.insertvalue %[[CAST_SHIFTED_BASE_PTR_4]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[C0_4:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %[[C0_4]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.mlir.constant(64 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %7 = view %6[%arg2][] : memref<2048xi8, 4> to memref<64x4xf32, 4> + return }