[mlir] Convert MemRefReinterpretCastOp to LLVM.
https://llvm.discourse.group/t/rfc-standard-memref-cast-ops/1454/15

Differential Revision: https://reviews.llvm.org/D90033
commit d6ab0474c6 (parent e56e7bd469)
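For orientation, this is the op the new pattern lowers, in the form used by the tests added below; the lowering replaces it with explicit construction of an LLVM memref descriptor:

    %output = memref_reinterpret_cast %input to
        offset: [0], sizes: [6, 1], strides: [1, 1]
        : memref<2x3xf32> to memref<6x1xf32>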
@@ -2416,6 +2416,114 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
  }
};

/// Lowers MemRefReinterpretCastOp into the construction of a target memref
/// descriptor: the source's allocated and aligned pointers are copied over,
/// and offset, sizes and strides come from the op's static attributes or
/// dynamic operands.
struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> {
  using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto castOp = cast<MemRefReinterpretCastOp>(op);
    MemRefReinterpretCastOp::Adaptor adaptor(operands,
                                             op->getAttrDictionary());
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(op, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, MemRefReinterpretCastOp castOp,
                                  MemRefReinterpretCastOp::Adaptor adaptor,
                                  Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointers(loc, rewriter, castOp.source(), adaptor.source(),
                    &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides. Dynamic values are consumed from the adaptor in
    // order, interleaved per dimension.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }

  /// Extracts allocated and aligned pointers from `convertedOperand`: read
  /// directly from a ranked memref descriptor, or, for an unranked source,
  /// loaded through the pointer to the underlying ranked descriptor.
  void extractPointers(Location loc, ConversionPatternRewriter &rewriter,
                       Value originalOperand, Value convertedOperand,
                       Value *allocatedPtr, Value *alignedPtr) const {
    Type operandType = originalOperand.getType();
    if (operandType.isa<MemRefType>()) {
      MemRefDescriptor desc(convertedOperand);
      *allocatedPtr = desc.allocatedPtr(rewriter, loc);
      *alignedPtr = desc.alignedPtr(rewriter, loc);
      return;
    }

    unsigned memorySpace =
        operandType.cast<UnrankedMemRefType>().getMemorySpace();
    Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
    LLVM::LLVMType llvmElementType =
        typeConverter.convertType(elementType).cast<LLVM::LLVMType>();
    LLVM::LLVMType elementPtrPtrType =
        llvmElementType.getPointerTo(memorySpace).getPointerTo();

    // Extract pointer to the underlying ranked memref descriptor and cast it
    // to ElemType**.
    UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
    Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value elementPtrPtr = rewriter.create<LLVM::BitcastOp>(
        loc, elementPtrPtrType, underlyingDescPtr);

    LLVM::LLVMType int32Type =
        typeConverter.convertType(rewriter.getI32Type()).cast<LLVM::LLVMType>();

    // Extract and set allocated pointer.
    *allocatedPtr = rewriter.create<LLVM::LoadOp>(loc, elementPtrPtr);

    // Extract and set aligned pointer: it sits right after the allocated
    // pointer in the ranked descriptor, at GEP index 1.
    Value one = rewriter.create<LLVM::ConstantOp>(
        loc, int32Type, rewriter.getI32IntegerAttr(1));
    Value alignedGep = rewriter.create<LLVM::GEPOp>(
        loc, elementPtrPtrType, elementPtrPtr, ValueRange({one}));
    *alignedPtr = rewriter.create<LLVM::LoadOp>(loc, alignedGep);
  }
};

struct DialectCastOpLowering
    : public ConvertOpToLLVMPattern<LLVM::DialectCastOp> {
  using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
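For reference, the descriptor populated above is the standard memref descriptor struct: allocated pointer, aligned pointer, offset, array of sizes, array of strides. For the rank-2 f32 memrefs used in the tests, the converted target type should be roughly the following (the exact printed form depends on the type converter at this revision):

    !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>

The isStructTy() guard in convertSourceMemRefToDescriptor ensures the pattern only fires when the target memref type converts to such a struct.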
@@ -3532,6 +3640,7 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
      DimOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefReinterpretCastOpLowering,
      RankOpLowering,
      StoreOpLowering,
      SubViewOpLowering,
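With the pattern registered in populateStdToLLVMMemoryConversionPatterns, the lowering runs as part of the standard-to-LLVM conversion; the integration test added below drives it end to end via mlir-opt -convert-scf-to-std -convert-std-to-llvm and executes the result under mlir-cpu-runner.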
@@ -432,3 +432,60 @@ func @memref_dim_with_dyn_index(%arg : memref<3x?xf32>, %idx : index) -> index {
  %result = dim %arg, %idx : memref<3x?xf32>
  return %result : index
}

// CHECK-LABEL: @memref_reinterpret_cast_ranked_to_static_shape
func @memref_reinterpret_cast_ranked_to_static_shape(%input : memref<2x3xf32>) {
  %output = memref_reinterpret_cast %input to
    offset: [0], sizes: [6, 1], strides: [1, 1]
    : memref<2x3xf32> to memref<6x1xf32>
  return
}
// CHECK: [[INPUT:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : [[TY:!.*]]
// CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY]]
// CHECK: [[BASE_PTR:%.*]] = llvm.extractvalue [[INPUT]][0] : [[TY]]
// CHECK: [[ALIGNED_PTR:%.*]] = llvm.extractvalue [[INPUT]][1] : [[TY]]
// CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
// CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
// CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]]
// Sizes and strides are set interleaved per dimension: [3, i] holds the size
// and [4, i] the stride of dimension i.
// CHECK: [[SIZE_0:%.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
// CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]]
// CHECK: [[STRIDE_0:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_4]][4, 0] : [[TY]]
// CHECK: [[SIZE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_5]][3, 1] : [[TY]]
// CHECK: [[STRIDE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]

// CHECK-LABEL: @memref_reinterpret_cast_unranked_to_dynamic_shape
func @memref_reinterpret_cast_unranked_to_dynamic_shape(%offset: index,
                                                        %size_0 : index,
                                                        %size_1 : index,
                                                        %stride_0 : index,
                                                        %stride_1 : index,
                                                        %input : memref<*xf32>) {
  %output = memref_reinterpret_cast %input to
    offset: [%offset], sizes: [%size_0, %size_1],
    strides: [%stride_0, %stride_1]
    : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
  return
}
// CHECK-SAME: ([[OFFSET:%[a-z,0-9]+]]: !llvm.i64,
// CHECK-SAME: [[SIZE_0:%[a-z,0-9]+]]: !llvm.i64, [[SIZE_1:%[a-z,0-9]+]]: !llvm.i64,
// CHECK-SAME: [[STRIDE_0:%[a-z,0-9]+]]: !llvm.i64, [[STRIDE_1:%[a-z,0-9]+]]: !llvm.i64,
// CHECK: [[INPUT:%.*]] = llvm.insertvalue {{.*}}[1] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY:!.*]]
// CHECK: [[DESCRIPTOR:%.*]] = llvm.extractvalue [[INPUT]][1] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
// CHECK: [[BASE_PTR:%.*]] = llvm.load [[BASE_PTR_PTR]] : !llvm.ptr<ptr<float>>
// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR]]{{\[}}[[C1]]]
// CHECK-SAME: : (!llvm.ptr<ptr<float>>, !llvm.i32) -> !llvm.ptr<ptr<float>>
// CHECK: [[ALIGNED_PTR:%.*]] = llvm.load [[ALIGNED_PTR_PTR]] : !llvm.ptr<ptr<float>>
// CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
// CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
// CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]]
// CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]]
// CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_4]][4, 0] : [[TY]]
// CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_5]][3, 1] : [[TY]]
// CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]
@@ -0,0 +1,105 @@
// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }

func @main() -> () {
  %c0 = constant 0 : index
  %c1 = constant 1 : index

  // Initialize input: input[i, j] = i * dim_y + j.
  %input = alloc() : memref<2x3xf32>
  %dim_x = dim %input, %c0 : memref<2x3xf32>
  %dim_y = dim %input, %c1 : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
    %prod = muli %i, %dim_y : index
    %val = addi %prod, %j : index
    %val_i64 = index_cast %val : index to i64
    %val_f32 = sitofp %val_i64 : i64 to f32
    store %val_f32, %input[%i, %j] : memref<2x3xf32>
  }
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
  // CHECK-NEXT: [0, 1, 2]
  // CHECK-NEXT: [3, 4, 5]

  // Test cases.
  call @cast_ranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
  call @cast_ranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
  call @cast_unranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
  call @cast_unranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
  return
}

func @cast_ranked_memref_to_static_shape(%input : memref<2x3xf32>) {
  %output = memref_reinterpret_cast %input to
    offset: [0], sizes: [6, 1], strides: [1, 1]
    : memref<2x3xf32> to memref<6x1xf32>

  %unranked_output = memref_cast %output
    : memref<6x1xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
  // CHECK-NEXT: [0],
  // CHECK-NEXT: [1],
  // CHECK-NEXT: [2],
  // CHECK-NEXT: [3],
  // CHECK-NEXT: [4],
  // CHECK-NEXT: [5]
  return
}

func @cast_ranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c6 = constant 6 : index
  %output = memref_reinterpret_cast %input to
    offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
    : memref<2x3xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>

  %unranked_output = memref_cast %output
    : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
  // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
  return
}

func @cast_unranked_memref_to_static_shape(%input : memref<2x3xf32>) {
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  %output = memref_reinterpret_cast %unranked_input to
    offset: [0], sizes: [6, 1], strides: [1, 1]
    : memref<*xf32> to memref<6x1xf32>

  %unranked_output = memref_cast %output
    : memref<6x1xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
  // CHECK-NEXT: [0],
  // CHECK-NEXT: [1],
  // CHECK-NEXT: [2],
  // CHECK-NEXT: [3],
  // CHECK-NEXT: [4],
  // CHECK-NEXT: [5]
  return
}

func @cast_unranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c6 = constant 6 : index
  %output = memref_reinterpret_cast %unranked_input to
    offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
    : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>

  %unranked_output = memref_cast %output
    : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
  // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
  return
}