[mlir][bufferization] Do not canonicalize to_tensor(to_memref(x))

This is a partial revert of D128615.

to_memref(to_tensor(x)) always be folded to x. But to_tensor(to_memref(x)) cannot be folded in the general case because writes to the intermediary memref may go unnoticed.

Differential Revision: https://reviews.llvm.org/D129354
This commit is contained in:
Matthias Springer 2022-07-09 09:15:36 +02:00
parent e1272ab6ec
commit fc9b37dd53
3 changed files with 5 additions and 17 deletions

View File

@ -539,20 +539,6 @@ OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
}
namespace {
/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
using OpRewritePattern<ToTensorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
PatternRewriter &rewriter) const final {
auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
if (!toMemrefOp)
return failure();
rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
return success();
}
};
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
@ -571,7 +557,7 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
results.add<DimOfToTensorFolder>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -787,7 +787,8 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
}
// CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
// CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
// CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
// CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
}

View File

@ -109,7 +109,8 @@
// CHECK: scf.yield %[[VAL_84]] : f64
// CHECK: }
// CHECK: memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
// CHECK: return %[[VAL_0]] : tensor<f64>
// CHECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
// CHECK: return %[[VAL_87]] : tensor<f64>
// CHECK: }
func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
%arga: tensor<64x32xf64, #SparseMatrix>,