From 4b644fca08efa32a6f1ff54a281b511c00cf2806 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Wed, 16 Mar 2022 04:51:17 +0000 Subject: [PATCH] [mlir][Linalg] Add multi-result op cast test. https://reviews.llvm.org/D121369 fixed an issue with canonicalizing a linalg op producer with a cast op consumer. Adding a test to verify that change. Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D121648 --- mlir/test/Dialect/Linalg/canonicalize.mlir | 34 ++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 0e0faab56f6c..eee6ebc90756 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -901,3 +901,37 @@ func @fold_conv_op_with_cast_consumer(%arg0 : tensor, // CHECK-SAME: outs(%[[OUT_CAST]] : // CHECK: %[[RESULT_CAST:.+]] = tensor.cast %[[CONV]] // CHECK: return %[[CONV]], %[[RESULT_CAST]] + +// ----- + +func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor) -> (tensor, tensor<2x3x4xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg0, %c1 : tensor + %d2 = tensor.dim %arg0, %c2 : tensor + %init1 = linalg.init_tensor [%d1, %d2, %d0] : tensor + %init2 = linalg.init_tensor [%d2, %d1, %d0] : tensor + %0:2 = linalg.generic { + iterator_types = ["parallel", "parallel", "parallel"], + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2, d0)>, + affine_map<(d0, d1, d2) -> (d2, d1, d0)>]} + ins(%arg0 : tensor) outs(%init1, %init2 : tensor, tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32) : + linalg.yield %b0, %b0 : f32, f32 + } -> (tensor, tensor) + %1 = tensor.cast %0#1 : tensor to tensor<2x3x4xf32> + return %0#0, %1 : tensor, tensor<2x3x4xf32> +} +// CHECK: func @fold_multi_use_generic_op_with_consumer +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32> +// CHECK-DAG: %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor to tensor<4x3x2xf32> +// CHECK-DAG: %[[INIT2:.+]] = linalg.init_tensor [3, 2, 4] : tensor<3x2x4xf32> +// CHECK: %[[GENERIC:.+]]:2 = linalg.generic +// CHECK-SAME: ins(%[[CAST]] : +// CHECK-SAME: outs(%[[INIT2]], %[[INIT1]] : +// CHECK: %[[RETURN_CAST:.+]] = tensor.cast %[[GENERIC]]#0 : tensor<3x2x4xf32> to tensor +// CHECK: return %[[RETURN_CAST]], %[[GENERIC]]#1