forked from OSchip/llvm-project
[mlir][MemRef] Make sure types match when folding dim(reshape)
Reshape can take integer types in addition to index, but dim always returns index. Differential Revision: https://reviews.llvm.org/D104287
This commit is contained in:
parent
662e074d90
commit
cd93935146
|
@ -770,8 +770,11 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
|
|||
// Place the load directly after the reshape to ensure that the shape memref
|
||||
// was not mutated.
|
||||
rewriter.setInsertionPointAfter(reshape);
|
||||
rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
|
||||
llvm::makeArrayRef({dim.index()}));
|
||||
Location loc = dim.getLoc();
|
||||
Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
|
||||
if (load.getType() != dim.getType())
|
||||
load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
|
||||
rewriter.replaceOp(dim, load);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -122,6 +122,26 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
|
|||
|
||||
// -----
|
||||
|
||||
// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
|
||||
// CHECK-LABEL: func @dim_of_memref_reshape_i32(
|
||||
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
|
||||
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
|
||||
// CHECK-NEXT: %[[IDX:.*]] = constant 3
|
||||
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
|
||||
// CHECK-NEXT: %[[CAST:.*]] = index_cast %[[DIM]]
|
||||
// CHECK-NOT: memref.dim
|
||||
// CHECK: return %[[CAST]] : index
|
||||
func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
|
||||
-> index {
|
||||
%c3 = constant 3 : index
|
||||
%0 = memref.reshape %arg0(%arg1)
|
||||
: (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
|
||||
%1 = memref.dim %0, %c3 : memref<*xf32>
|
||||
return %1 : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
|
||||
// CHECK-LABEL: func @fold_dim_of_tensor.cast
|
||||
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
|
||||
|
|
Loading…
Reference in New Issue