From fc9b37dd532dc68018c0c5947030b34ebcf68d14 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 9 Jul 2022 09:15:36 +0200 Subject: [PATCH] [mlir][bufferization] Do not canonicalize to_tensor(to_memref(x)) This is a partial revert of D128615. to_memref(to_tensor(x)) always be folded to x. But to_tensor(to_memref(x)) cannot be folded in the general case because writes to the intermediary memref may go unnoticed. Differential Revision: https://reviews.llvm.org/D129354 --- .../Bufferization/IR/BufferizationOps.cpp | 16 +--------------- mlir/test/Dialect/SCF/canonicalize.mlir | 3 ++- .../SparseTensor/sparse_vector_chain.mlir | 3 ++- 3 files changed, 5 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index 35f6f1b6a97f..4ab904ea3930 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -539,20 +539,6 @@ OpFoldResult ToTensorOp::fold(ArrayRef) { } namespace { -/// Canonicalize bufferization.to_tensor + bufferization.to_memref. -struct ToTensorToMemrefFolding : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ToTensorOp toTensorOp, - PatternRewriter &rewriter) const final { - auto toMemrefOp = toTensorOp.getMemref().getDefiningOp(); - if (!toMemrefOp) - return failure(); - rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor()); - return success(); - } -}; - struct DimOfToTensorFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -571,7 +557,7 @@ struct DimOfToTensorFolder : public OpRewritePattern { void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 535a00706100..8e087fc0f38a 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -787,7 +787,8 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>, } // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32> - // CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32> + // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32> + // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32> return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32> } diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir index f24048e60e07..df55b8373e0e 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir @@ -109,7 +109,8 @@ // CHECK: scf.yield %[[VAL_84]] : f64 // CHECK: } // CHECK: memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref -// CHECK: return %[[VAL_0]] : tensor +// CHECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref +// CHECK: return %[[VAL_87]] : tensor // CHECK: } func.func @sparse_matrix_sum(%argx: tensor {linalg.inplaceable = true}, %arga: tensor<64x32xf64, #SparseMatrix>,