[mlir] Convert MemRefReinterpretCastOp to LLVM.

https://llvm.discourse.group/t/rfc-standard-memref-cast-ops/1454/15

Differential Revision: https://reviews.llvm.org/D90033
This commit is contained in:
Alexander Belyaev 2020-10-23 14:34:50 +02:00
parent e56e7bd469
commit d6ab0474c6
3 changed files with 271 additions and 0 deletions

View File

@ -2416,6 +2416,114 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
}
};
/// Lowers std.memref_reinterpret_cast to the LLVM dialect by building a new
/// memref descriptor for the result type: the allocated/aligned pointers are
/// taken from the source memref, while offset, sizes and strides come from the
/// op's operands (dynamic) or attributes (static).
struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> {
  using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto castOp = cast<MemRefReinterpretCastOp>(op);
    // The adaptor gives named access to the already-converted operands
    // (source/offsets/sizes/strides) together with the op's attributes.
    MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary());
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(op, {descriptor});
    return success();
  }

private:
  /// Populates *descriptor with a fully-initialized memref descriptor for the
  /// cast result. Fails when the result memref type does not convert to an
  /// LLVM struct type.
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, MemRefReinterpretCastOp castOp,
                                  MemRefReinterpretCastOp::Adaptor adaptor,
                                  Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointers(loc, rewriter, castOp.source(), adaptor.source(),
                    &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset: forward the SSA operand when dynamic, otherwise emit a
    // constant from the static attribute.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides. Dynamic values are consumed in order from the
    // sizes()/strides() operand lists, so separate cursors track the next
    // unused operand of each list.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }

  /// Extracts the allocated and aligned pointers of the source memref. For a
  /// ranked source they are read straight out of its descriptor; for an
  /// unranked source they are loaded from the ranked descriptor that the
  /// unranked descriptor's opaque pointer refers to.
  void extractPointers(Location loc, ConversionPatternRewriter &rewriter,
                       Value originalOperand, Value convertedOperand,
                       Value *allocatedPtr, Value *alignedPtr) const {
    Type operandType = originalOperand.getType();
    if (operandType.isa<MemRefType>()) {
      // Ranked source: the converted operand already is a memref descriptor.
      MemRefDescriptor desc(convertedOperand);
      *allocatedPtr = desc.allocatedPtr(rewriter, loc);
      *alignedPtr = desc.alignedPtr(rewriter, loc);
      return;
    }

    unsigned memorySpace =
        operandType.cast<UnrankedMemRefType>().getMemorySpace();
    Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
    LLVM::LLVMType llvmElementType =
        typeConverter.convertType(elementType).cast<LLVM::LLVMType>();
    LLVM::LLVMType elementPtrPtrType =
        llvmElementType.getPointerTo(memorySpace).getPointerTo();

    // Extract pointer to the underlying ranked memref descriptor and cast it to
    // ElemType**.
    UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
    Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value elementPtrPtr = rewriter.create<LLVM::BitcastOp>(
        loc, elementPtrPtrType, underlyingDescPtr);

    LLVM::LLVMType int32Type =
        typeConverter.convertType(rewriter.getI32Type()).cast<LLVM::LLVMType>();

    // Extract and set allocated pointer: the first field of the ranked
    // descriptor, read directly through the casted pointer.
    *allocatedPtr = rewriter.create<LLVM::LoadOp>(loc, elementPtrPtr);

    // Extract and set aligned pointer: GEP one ElemType* past the allocated
    // pointer to reach the descriptor's second field.
    Value one = rewriter.create<LLVM::ConstantOp>(
        loc, int32Type, rewriter.getI32IntegerAttr(1));
    Value alignedGep = rewriter.create<LLVM::GEPOp>(
        loc, elementPtrPtrType, elementPtrPtr, ValueRange({one}));
    *alignedPtr = rewriter.create<LLVM::LoadOp>(loc, alignedGep);
  }
};
struct DialectCastOpLowering
: public ConvertOpToLLVMPattern<LLVM::DialectCastOp> {
using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
@ -3532,6 +3640,7 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
DimOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
MemRefReinterpretCastOpLowering,
RankOpLowering,
StoreOpLowering,
SubViewOpLowering,

View File

@ -432,3 +432,60 @@ func @memref_dim_with_dyn_index(%arg : memref<3x?xf32>, %idx : index) -> index {
%result = dim %arg, %idx : memref<3x?xf32>
return %result : index
}
// Ranked source, fully static result layout: every offset/size/stride is
// materialized as an llvm.mlir.constant and inserted into a fresh descriptor.
// Descriptor layout: [0]=allocated ptr, [1]=aligned ptr, [2]=offset,
// [3, i]=size i, [4, i]=stride i; the lowering emits size then stride per dim.
// CHECK-LABEL: @memref_reinterpret_cast_ranked_to_static_shape
func @memref_reinterpret_cast_ranked_to_static_shape(%input : memref<2x3xf32>) {
  %output = memref_reinterpret_cast %input to
            offset: [0], sizes: [6, 1], strides: [1, 1]
            : memref<2x3xf32> to memref<6x1xf32>
  return
}
// CHECK: [[INPUT:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : [[TY:!.*]]
// CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY]]
// CHECK: [[BASE_PTR:%.*]] = llvm.extractvalue [[INPUT]][0] : [[TY]]
// CHECK: [[ALIGNED_PTR:%.*]] = llvm.extractvalue [[INPUT]][1] : [[TY]]
// CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
// CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
// CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]]
// CHECK: [[SIZE_0:%.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
// CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]]
// CHECK: [[STRIDE_0:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_4]][4, 0] : [[TY]]
// CHECK: [[SIZE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_5]][3, 1] : [[TY]]
// CHECK: [[STRIDE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]
// Unranked source, fully dynamic result layout: the base/aligned pointers are
// loaded out of the ranked descriptor behind the unranked one (bitcast +
// load, then GEP 1 + load), and all offset/size/stride operands are forwarded
// into the new descriptor unchanged.
// CHECK-LABEL: @memref_reinterpret_cast_unranked_to_dynamic_shape
func @memref_reinterpret_cast_unranked_to_dynamic_shape(%offset: index,
                                                        %size_0 : index,
                                                        %size_1 : index,
                                                        %stride_0 : index,
                                                        %stride_1 : index,
                                                        %input : memref<*xf32>) {
  %output = memref_reinterpret_cast %input to
            offset: [%offset], sizes: [%size_0, %size_1],
            strides: [%stride_0, %stride_1]
            : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
  return
}
// CHECK-SAME: ([[OFFSET:%[a-z,0-9]+]]: !llvm.i64,
// CHECK-SAME: [[SIZE_0:%[a-z,0-9]+]]: !llvm.i64, [[SIZE_1:%[a-z,0-9]+]]: !llvm.i64,
// CHECK-SAME: [[STRIDE_0:%[a-z,0-9]+]]: !llvm.i64, [[STRIDE_1:%[a-z,0-9]+]]: !llvm.i64,
// CHECK: [[INPUT:%.*]] = llvm.insertvalue {{.*}}[1] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY:!.*]]
// CHECK: [[DESCRIPTOR:%.*]] = llvm.extractvalue [[INPUT]][1] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
// CHECK: [[BASE_PTR:%.*]] = llvm.load [[BASE_PTR_PTR]] : !llvm.ptr<ptr<float>>
// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR]]{{\[}}[[C1]]]
// CHECK-SAME: : (!llvm.ptr<ptr<float>>, !llvm.i32) -> !llvm.ptr<ptr<float>>
// CHECK: [[ALIGNED_PTR:%.*]] = llvm.load [[ALIGNED_PTR_PTR]] : !llvm.ptr<ptr<float>>
// CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
// CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
// CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]]
// CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]]
// CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_4]][4, 0] : [[TY]]
// CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_5]][3, 1] : [[TY]]
// CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]

View File

@ -0,0 +1,105 @@
// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s
// External printer provided by the MLIR runner utils shared library.
func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }

// Fills a 2x3 buffer with 0..5 in row-major order, prints it as a sanity
// check, then runs each reinterpret-cast variant on it.
func @main() -> () {
  %c0 = constant 0 : index
  %c1 = constant 1 : index

  // Initialize input.
  %input = alloc() : memref<2x3xf32>
  %dim_x = dim %input, %c0 : memref<2x3xf32>
  %dim_y = dim %input, %c1 : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
    // Element value = linearized row-major index: i * dim_y + j.
    %prod = muli %i, %dim_y : index
    %val = addi %prod, %j : index
    %val_i64 = index_cast %val : index to i64
    %val_f32 = sitofp %val_i64 : i64 to f32
    store %val_f32, %input[%i, %j] : memref<2x3xf32>
  }
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
  // CHECK-NEXT: [0, 1, 2]
  // CHECK-NEXT: [3, 4, 5]

  // Test cases.
  call @cast_ranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
  call @cast_ranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
  call @cast_unranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
  call @cast_unranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
  return
}
// Ranked 2x3 source viewed as a static 6x1 memref: same underlying data, only
// the descriptor's shape and strides are rewritten.
func @cast_ranked_memref_to_static_shape(%input : memref<2x3xf32>) {
  %output = memref_reinterpret_cast %input to
            offset: [0], sizes: [6, 1], strides: [1, 1]
            : memref<2x3xf32> to memref<6x1xf32>
  %unranked_output = memref_cast %output
      : memref<6x1xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
  // CHECK-NEXT: [0],
  // CHECK-NEXT: [1],
  // CHECK-NEXT: [2],
  // CHECK-NEXT: [3],
  // CHECK-NEXT: [4],
  // CHECK-NEXT: [5]
  return
}
// Ranked 2x3 source viewed as a dynamic 1x6 memref: offset, sizes, and
// strides are all passed as SSA values rather than static attributes.
func @cast_ranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c6 = constant 6 : index
  %output = memref_reinterpret_cast %input to
            offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
            : memref<2x3xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
  %unranked_output = memref_cast %output
      : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
  // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
  return
}
// Unranked source, static 6x1 result: exercises the lowering path that loads
// the base/aligned pointers out of the descriptor behind the unranked memref.
func @cast_unranked_memref_to_static_shape(%input : memref<2x3xf32>) {
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  %output = memref_reinterpret_cast %unranked_input to
            offset: [0], sizes: [6, 1], strides: [1, 1]
            : memref<*xf32> to memref<6x1xf32>
  %unranked_output = memref_cast %output
      : memref<6x1xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
  // CHECK-NEXT: [0],
  // CHECK-NEXT: [1],
  // CHECK-NEXT: [2],
  // CHECK-NEXT: [3],
  // CHECK-NEXT: [4],
  // CHECK-NEXT: [5]
  return
}
// Unranked source with a fully dynamic 1x6 result: combines the unranked
// pointer-extraction path with SSA-valued offset/sizes/strides.
func @cast_unranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c6 = constant 6 : index
  %output = memref_reinterpret_cast %unranked_input to
            offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
            : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
  %unranked_output = memref_cast %output
      : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
  // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
  return
}