[mlir][MemRef] Make sure types match when folding dim(reshape)

Reshape can take integer types in addition to index, but dim always
returns index.

Differential Revision: https://reviews.llvm.org/D104287
This commit is contained in:
Benjamin Kramer 2021-06-15 12:32:16 +02:00
parent 662e074d90
commit cd93935146
2 changed files with 25 additions and 2 deletions

View File

@ -770,8 +770,11 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
// Place the load directly after the reshape to ensure that the shape memref
// was not mutated.
rewriter.setInsertionPointAfter(reshape);
rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
llvm::makeArrayRef({dim.index()}));
Location loc = dim.getLoc();
Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
if (load.getType() != dim.getType())
load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
rewriter.replaceOp(dim, load);
return success();
}
};

View File

@ -122,6 +122,26 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
// -----
// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_i32(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
// CHECK-NEXT: %[[IDX:.*]] = constant 3
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: %[[CAST:.*]] = index_cast %[[DIM]]
// CHECK-NOT: memref.dim
// CHECK: return %[[CAST]] : index
func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
-> index {
%c3 = constant 3 : index
%0 = memref.reshape %arg0(%arg1)
: (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
%1 = memref.dim %0, %c3 : memref<*xf32>
return %1 : index
}
// -----
// Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
// CHECK-LABEL: func @fold_dim_of_tensor.cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>