From 500d4c45ba7f31907a64dead8ddb292649e6ce75 Mon Sep 17 00:00:00 2001 From: cwz920716 Date: Wed, 15 Sep 2021 02:59:18 +0000 Subject: [PATCH] [MLIR] Use memref.copy ops in BufferResultsToOutParams pass. Both copy/alloc ops are using memref dialect after this change. Reviewed By: silvas, mehdi_amini Differential Revision: https://reviews.llvm.org/D109480 --- mlir/include/mlir/Transforms/Passes.td | 2 +- mlir/lib/Transforms/BufferResultsToOutParams.cpp | 3 +-- mlir/lib/Transforms/CMakeLists.txt | 1 - mlir/lib/Transforms/PassDetail.h | 4 ---- mlir/test/Transforms/buffer-results-to-out-params.mlir | 10 +++++----- 5 files changed, 7 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 45d72c061d6e..91af2a2c56a9 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -352,7 +352,7 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp"> works for static shaped memrefs. }]; let constructor = "mlir::createBufferResultsToOutParamsPass()"; - let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"]; + let dependentDialects = ["memref::MemRefDialect"]; } def Canonicalizer : Pass<"canonicalize"> { diff --git a/mlir/lib/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Transforms/BufferResultsToOutParams.cpp index 0920d1321e42..73cc073b496b 100644 --- a/mlir/lib/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Transforms/BufferResultsToOutParams.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Operation.h" @@ -71,7 +70,7 @@ static void updateReturnOps(FuncOp func, } OpBuilder builder(op); for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) - builder.create(op.getLoc(), std::get<0>(t), + builder.create(op.getLoc(), std::get<0>(t), std::get<1>(t)); builder.create(op.getLoc(), keepAsReturnOperands); op.erase(); diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 99133af8b981..54f3693c89c6 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -33,7 +33,6 @@ add_mlir_library(MLIRTransforms MLIRAffine MLIRAnalysis MLIRCopyOpInterface - MLIRLinalg MLIRLoopLikeInterface MLIRMemRef MLIRSCF diff --git a/mlir/lib/Transforms/PassDetail.h b/mlir/lib/Transforms/PassDetail.h index 0f998a7147ce..2cb0e12b1cf2 100644 --- a/mlir/lib/Transforms/PassDetail.h +++ b/mlir/lib/Transforms/PassDetail.h @@ -18,10 +18,6 @@ class AffineDialect; template void registerDialect(DialectRegistry ®istry); -namespace linalg { -class LinalgDialect; -} // end namespace linalg - namespace memref { class MemRefDialect; } // end namespace memref diff --git a/mlir/test/Transforms/buffer-results-to-out-params.mlir b/mlir/test/Transforms/buffer-results-to-out-params.mlir index cac3e7461225..063d0d39f5c1 100644 --- a/mlir/test/Transforms/buffer-results-to-out-params.mlir +++ b/mlir/test/Transforms/buffer-results-to-out-params.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func @basic( // CHECK-SAME: %[[ARG:.*]]: memref) { // CHECK: %[[RESULT:.*]] = "test.source"() : () -> memref -// CHECK: linalg.copy(%[[RESULT]], %[[ARG]]) : memref, memref +// CHECK: memref.copy %[[RESULT]], %[[ARG]] : memref to memref // CHECK: return // CHECK: } func @basic() -> (memref) { @@ -15,7 +15,7 @@ func @basic() -> (memref) { // CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>, // CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) { // CHECK: %[[RESULT:.*]] = "test.source"() : () -> memref<2xf32> -// CHECK: linalg.copy(%[[RESULT]], %[[ARG1]]) : memref<2xf32>, memref<2xf32> +// CHECK: memref.copy %[[RESULT]], %[[ARG1]] : memref<2xf32> to memref<2xf32> // CHECK: return // CHECK: } func @presence_of_existing_arguments(%arg0: memref<1xf32>) -> (memref<2xf32>) { @@ -27,8 +27,8 @@ func @presence_of_existing_arguments(%arg0: memref<1xf32>) -> (memref<2xf32>) { // CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>, // CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) { // CHECK: %[[RESULTS:.*]]:2 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>) -// CHECK: linalg.copy(%[[RESULTS]]#0, %[[ARG0]]) : memref<1xf32>, memref<1xf32> -// CHECK: linalg.copy(%[[RESULTS]]#1, %[[ARG1]]) : memref<2xf32>, memref<2xf32> +// CHECK: memref.copy %[[RESULTS]]#0, %[[ARG0]] : memref<1xf32> to memref<1xf32> +// CHECK: memref.copy %[[RESULTS]]#1, %[[ARG1]] : memref<2xf32> to memref<2xf32> // CHECK: return // CHECK: } func @multiple_results() -> (memref<1xf32>, memref<2xf32>) { @@ -39,7 +39,7 @@ func @multiple_results() -> (memref<1xf32>, memref<2xf32>) { // CHECK-LABEL: func @non_memref_types( // CHECK-SAME: %[[OUTPARAM:.*]]: memref) -> (i1, i32) { // CHECK: %[[RESULT1:.*]]:3 = "test.source"() : () -> (i1, memref, i32) -// CHECK: linalg.copy(%[[RESULT1]]#1, %[[OUTPARAM]]) : memref, memref +// CHECK: memref.copy %[[RESULT1]]#1, %[[OUTPARAM]] : memref to memref // CHECK: return %[[RESULT1]]#0, %[[RESULT1]]#2 : i1, i32 // CHECK: } func @non_memref_types() -> (i1, memref, i32) {