From 1b579d998ad5a45282e8daaf0bac26df3d3c1f29 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 12 Dec 2019 09:56:12 -0800 Subject: [PATCH] [Linalg] Add test for fusion of GenericOp with IndexedGenericOp. PiperOrigin-RevId: 285211797 --- mlir/test/Dialect/Linalg/fusion.mlir | 57 +++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir index cbb99a766732..ba74813f566b 100644 --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -3,7 +3,10 @@ #map0 = (d0) -> (d0 + 2) #map1 = (d0) -> (d0 + 4) #map2 = (d0) -> (d0 + 3) - +#map3 = (d0)[s0, s1] -> (d0 * s1 + s0) +#map4 = (d0) -> (d0) +#map5 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2) +#map6 = (d0, d1) -> (d0, d1) // CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1) func @f1(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { @@ -384,3 +387,55 @@ func @pointwise_no_view(%M: index, %N: index) { // CHECK: addf // CHECK: linalg.generic // CHECK: mulf + +func @indexed_generic_test(%A: memref, + %B: memref, + %C: memref, + %D: memref) { + linalg.generic #pointwise_2d_trait %A, %B, %C { + ^bb0(%e: f32, %arg5: f32, %arg6: f32): // no predecessors + %2 = addf %e, %arg5 : f32 + linalg.yield %2 : f32 + }: memref, memref, memref + %c1 = constant 1 : index + %c0 = constant 0 : index + %c25 = constant 25 : index + %c10 = constant 10 : index + %0 = dim %C, 0 : memref + %1 = dim %C, 1 : memref + %2 = dim %D, 0 : memref + %3 = dim %D, 1 : memref + loop.for %arg2 = %c0 to %0 step %c10 { + loop.for %arg3 = %c0 to %1 step %c25 { + %4 = std.subview %C[%arg2, %arg3][%c10, %c25][%c1, %c1] : + memref to memref + %5 = std.subview %D[%arg2, %arg3][%c10, %c25][%c1, %c1] : + memref to memref + linalg.indexed_generic { + indexing_maps = [#map6, #map6], + iterator_types = ["parallel", "parallel"], + args_in = 1, + args_out = 1 + } %4, %5 { + ^bb0(%arg4: index, %arg5: index, %arg6: f32, %arg7: f32): + %6 = addi %arg4, %arg2 : index + %7 = addi %arg5, %arg3 : index + %8 = index_cast %6 : index to i32 + %9 = sitofp %8 : i32 to f32 + %10 = index_cast %7 : index to i32 + %11 = sitofp %10 : i32 to f32 + %12 = addf %9, %11 : f32 + linalg.yield %12 : f32 + }: memref, memref + } + } + return +} +// CHECK-LABEL: func @indexed_generic_test +// CHECK: loop.for +// CHECK: loop.for +// CHECK-NOT: loop.for +// CHECK: linalg.generic +// CHECK: addf +// CHECK: linalg.indexed_generic +// CHECK: index_cast