[mlir][memref] Missing type conversion in memref.reshape llvm lowering

Shape can be memref of index type, so memref::LoadOp result need to be converted into llvm type.

Differential Revision: https://reviews.llvm.org/D129965
This commit is contained in:
Ivan Butygin 2022-07-17 18:47:22 +02:00
parent 70056d04e2
commit d4217e6cc8
2 changed files with 37 additions and 0 deletions

View File

@ -1161,6 +1161,11 @@ private:
Value shapeOp = reshapeOp.getShape();
Value index = createIndexConstant(rewriter, loc, i);
dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
Type indexType = getIndexType();
if (dimSize.getType() != indexType)
dimSize = typeConverter->materializeTargetConversion(
rewriter, loc, indexType, dimSize);
assert(dimSize && "Invalid memref element type");
}
desc.setSize(rewriter, loc, i, dimSize);

View File

@ -306,3 +306,35 @@ func.func @memref.reshape.dynamic.dim(%arg: memref<?x?x?xf32>, %shape: memref<4x
return %0 : memref<?x?x12x32xf32>
// CHECK: return %[[result_cast]] : memref<?x?x12x32xf32>
}
// -----
// CHECK-LABEL: func @memref.reshape_index
// CHECK-SAME: %[[arg:.*]]: memref<?x?xi32>, %[[shape:.*]]: memref<1xindex>
func.func @memref.reshape_index(%arg0: memref<?x?xi32>, %shape: memref<1xindex>) -> memref<?xi32> {
// CHECK: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref<?x?xi32> to !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<1xindex> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[alloc_ptr:.*]] = llvm.extractvalue %[[arg_cast]][0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[align_ptr:.*]] = llvm.extractvalue %[[arg_cast]][1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[insert0:.*]] = llvm.insertvalue %[[alloc_ptr]], %[[undef:.*]][0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[insert1:.*]] = llvm.insertvalue %[[align_ptr]], %[[insert0:.*]][1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[zero0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero0]], %[[insert1:.*]][2] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[one0:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[zero1:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[shape_ptr0:.*]] = llvm.extractvalue %[[shape_cast:.*]][1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[shape_gep0:.*]] = llvm.getelementptr %[[shape_ptr0:.*]][%[[zero1:.*]]] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
// CHECK: %[[shape_load0:.*]] = llvm.load %[[shape_gep0:.*]] : !llvm.ptr<i64>
// CHECK: %[[insert3:.*]] = llvm.insertvalue %[[shape_load0:.*]], %[[insert2:.*]][3, 0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one0:.*]], %[[insert3:.*]][4, 0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[result_cast:.*]] = builtin.unrealized_conversion_cast %[[insert4:.*]] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)> to memref<?xi32>
// CHECK: return %[[result_cast:.*]] : memref<?xi32>
%1 = memref.reshape %arg0(%shape) : (memref<?x?xi32>, memref<1xindex>) -> memref<?xi32>
return %1 : memref<?xi32>
}