[mlir] Support normalizing memrefs with MemRef_ReinterpretCastOp

This patch enables normalizing memrefs with MemRef_ReinterpretCastOp by
adding MemRefsNormalizable trait in the Op definition.

Signed-off-by: Haruki Imai <imaihal@jp.ibm.com>

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D107425
This commit is contained in:
Haruki Imai 2021-08-11 01:15:07 +05:30 committed by Uday Bondhugula
parent 17db125b48
commit b34b1c6955
2 changed files with 20 additions and 1 deletions

View File

@ -778,7 +778,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
def MemRef_ReinterpretCastOp:
BaseOpWithOffsetSizesAndStrides<MemRef_Dialect, "reinterpret_cast", [
NoSideEffect, AttrSizedOperandSegments, ViewLikeOpInterface,
OffsetSizeAndStrideOpInterface
OffsetSizeAndStrideOpInterface, MemRefsNormalizable
]> {
let summary = "memref reinterpret cast operation";
let description = [{

View File

@ -112,3 +112,22 @@ func @test_norm_ret(%arg0: memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14
// Test with an arbitrary op that references the function symbol.
"test.op_funcref"() {func = @test_norm_mix} : () -> ()
// -----
#map_1d_tile = affine_map<(d0) -> (d0 floordiv 32, d0 mod 32)>
// Test with memref.reinterpret_cast
// CHECK-LABEL: test_norm_reinterpret_cast
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x32xf32>) -> memref<3x1x1xf32> {
func @test_norm_reinterpret_cast(%arg0 : memref<3xf32, #map_1d_tile>) -> (memref<3x1x1xf32>) {
%0 = memref.alloc() : memref<3xf32>
"test.op_norm"(%arg0, %0) : (memref<3xf32, #map_1d_tile>, memref<3xf32>) -> ()
%1 = memref.reinterpret_cast %0 to offset: [0], sizes: [3, 1, 1], strides: [1, 1, 1] : memref<3xf32> to memref<3x1x1xf32>
// CHECK: %[[v0:.*]] = memref.alloc() : memref<3xf32>
// CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x32xf32>, memref<3xf32>) -> ()
// CHECK: memref.reinterpret_cast %[[v0]] to offset: [0], sizes: [3, 1, 1], strides: [1, 1, 1] : memref<3xf32> to memref<3x1x1xf32>
return %1 : memref<3x1x1xf32>
}