forked from OSchip/llvm-project
[mlir][memref] Missing type conversion in memref.reshape llvm lowering
Shape can be memref of index type, so memref::LoadOp result need to be converted into llvm type. Differential Revision: https://reviews.llvm.org/D129965
This commit is contained in:
parent
70056d04e2
commit
d4217e6cc8
|
@ -1161,6 +1161,11 @@ private:
|
|||
Value shapeOp = reshapeOp.getShape();
|
||||
Value index = createIndexConstant(rewriter, loc, i);
|
||||
dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
|
||||
Type indexType = getIndexType();
|
||||
if (dimSize.getType() != indexType)
|
||||
dimSize = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, indexType, dimSize);
|
||||
assert(dimSize && "Invalid memref element type");
|
||||
}
|
||||
|
||||
desc.setSize(rewriter, loc, i, dimSize);
|
||||
|
|
|
@ -306,3 +306,35 @@ func.func @memref.reshape.dynamic.dim(%arg: memref<?x?x?xf32>, %shape: memref<4x
|
|||
return %0 : memref<?x?x12x32xf32>
|
||||
// CHECK: return %[[result_cast]] : memref<?x?x12x32xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @memref.reshape_index
|
||||
// CHECK-SAME: %[[arg:.*]]: memref<?x?xi32>, %[[shape:.*]]: memref<1xindex>
|
||||
func.func @memref.reshape_index(%arg0: memref<?x?xi32>, %shape: memref<1xindex>) -> memref<?xi32> {
|
||||
// CHECK: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref<?x?xi32> to !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<1xindex> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: %[[alloc_ptr:.*]] = llvm.extractvalue %[[arg_cast]][0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: %[[align_ptr:.*]] = llvm.extractvalue %[[arg_cast]][1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: %[[insert0:.*]] = llvm.insertvalue %[[alloc_ptr]], %[[undef:.*]][0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: %[[insert1:.*]] = llvm.insertvalue %[[align_ptr]], %[[insert0:.*]][1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
|
||||
// CHECK: %[[zero0:.*]] = llvm.mlir.constant(0 : index) : i64
|
||||
// CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero0]], %[[insert1:.*]][2] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
|
||||
// CHECK: %[[one0:.*]] = llvm.mlir.constant(1 : index) : i64
|
||||
// CHECK: %[[zero1:.*]] = llvm.mlir.constant(0 : index) : i64
|
||||
|
||||
// CHECK: %[[shape_ptr0:.*]] = llvm.extractvalue %[[shape_cast:.*]][1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: %[[shape_gep0:.*]] = llvm.getelementptr %[[shape_ptr0:.*]][%[[zero1:.*]]] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
|
||||
// CHECK: %[[shape_load0:.*]] = llvm.load %[[shape_gep0:.*]] : !llvm.ptr<i64>
|
||||
// CHECK: %[[insert3:.*]] = llvm.insertvalue %[[shape_load0:.*]], %[[insert2:.*]][3, 0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one0:.*]], %[[insert3:.*]][4, 0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
|
||||
// CHECK: %[[result_cast:.*]] = builtin.unrealized_conversion_cast %[[insert4:.*]] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)> to memref<?xi32>
|
||||
// CHECK: return %[[result_cast:.*]] : memref<?xi32>
|
||||
|
||||
%1 = memref.reshape %arg0(%shape) : (memref<?x?xi32>, memref<1xindex>) -> memref<?xi32>
|
||||
return %1 : memref<?xi32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue