[mlir][bufferize] Improve to_tensor/to_memref folding

Differential Revision: https://reviews.llvm.org/D128615
Matthias Springer 2022-06-27 21:34:09 +02:00
parent 11b414463d
commit cb47124179
4 changed files with 26 additions and 22 deletions


@@ -55,8 +55,7 @@ FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
 LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                       ToMemrefOp toMemref,
-                                       bool allowSameType = true);
+                                       ToMemrefOp toMemref);
 
 } // namespace bufferization
 } // namespace mlir
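
With the `allowSameType` parameter gone, `foldToMemrefToTensorPair` now unconditionally handles the case where the source and destination memref types already match. A minimal sketch of that case (illustrative IR, not taken from the patch's tests; the function and value names are made up):

func.func @fold_same_type(%m: memref<4xf32>) -> memref<4xf32> {
  %t = bufferization.to_tensor %m : memref<4xf32>
  %m2 = bufferization.to_memref %t : memref<4xf32>
  // foldToMemrefToTensorPair replaces all uses of %m2 with %m, so this
  // becomes `return %m` and both conversion ops become dead.
  return %m2 : memref<4xf32>
}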


@@ -84,8 +84,9 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
 
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
-LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
-    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
+LogicalResult
+mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
+                                              ToMemrefOp toMemref) {
   auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
   if (!memrefToTensor)
     return failure();
@@ -95,9 +96,6 @@ LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
 
   // Directly rewrite if the type did not change.
   if (srcType == destType) {
-    // Function can be configured to only handle cases where a cast is needed.
-    if (!allowSameType)
-      return failure();
     rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
     return success();
   }
@@ -541,6 +539,19 @@ OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
 }
 
 namespace {
 
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
+struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
+  using OpRewritePattern<ToTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
+                                PatternRewriter &rewriter) const final {
+    auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
+    if (!toMemrefOp)
+      return failure();
+    rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
+    return success();
+  }
+};
+
 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
@@ -556,12 +567,11 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
     return success();
   }
 };
 
 } // namespace
 
 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfToTensorFolder>(context);
+  results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
 }
 
 //===----------------------------------------------------------------------===//
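
The newly registered ToTensorToMemrefFolding pattern folds the reverse round-trip, to_tensor(to_memref(x)), back to the original tensor. A hedged sketch of its effect (illustrative IR; names are made up):

func.func @fold_round_trip(%t: tensor<4xf32>) -> tensor<4xf32> {
  %m = bufferization.to_memref %t : memref<4xf32>
  %u = bufferization.to_tensor %m : memref<4xf32>
  // The pattern replaces %u with %t, leaving `return %t`; the to_memref op
  // is then trivially dead.
  return %u : tensor<4xf32>
}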
@@ -601,17 +611,14 @@ struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
   }
 };
 
-/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
-/// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in.
-struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
+/// cast if necessary.
+struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                 PatternRewriter &rewriter) const final {
-    // Only handle cases where a cast is needed. The other case is handled by
-    // the folder.
-    return foldToMemrefToTensorPair(rewriter, toMemref,
-                                    /*allowSameType=*/false);
+    return foldToMemrefToTensorPair(rewriter, toMemref);
   }
 };
 
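
When the two memref types differ, the shared helper still applies from this pattern and, per its documentation, inserts a memref.cast (presumably via the castOrReallocMemRefValue helper visible in the hunk headers above, which may fall back to an allocation plus copy for layouts that are not guaranteed cast-compatible). A rough sketch under those assumptions, with an illustrative dynamic-offset layout:

#layout = affine_map<(d0)[s0] -> (d0 + s0)>

func.func @fold_with_cast(%m: memref<4xf32>) -> memref<4xf32, #layout> {
  %t = bufferization.to_tensor %m : memref<4xf32>
  %m2 = bufferization.to_memref %t : memref<4xf32, #layout>
  // Expected to fold to roughly:
  //   %m2 = memref.cast %m : memref<4xf32> to memref<4xf32, #layout>
  return %m2 : memref<4xf32, #layout>
}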
@@ -651,8 +658,8 @@ struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
 
 void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
-      context);
+  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
+              ToMemrefToTensorFolding>(context);
 }
 
 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,


@@ -787,8 +787,7 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
   }
 
 // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
-// CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
-// CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+// CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
   return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
 }
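
The updated CHECK lines follow from the new to_tensor(to_memref(x)) canonicalization: the second result, which is never written inside the loop, no longer round-trips through its buffer, and %[[T1]] is returned directly. A simplified, hedged sketch (assuming %[[M1]] is the buffer obtained from %[[T1]], which is what the fold implies):

func.func @simplified_last_value(%t1: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %m1 = bufferization.to_memref %t1 : memref<128x128xf32>
  %r1 = bufferization.to_tensor %m1 : memref<128x128xf32>
  // Previously the test checked for `return %r1`; the round-trip now folds
  // away and %t1 is returned directly.
  return %r1 : tensor<128x128xf32>
}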


@@ -109,8 +109,7 @@
 // CHECK: scf.yield %[[VAL_84]] : f64
 // CHECK: }
 // CHECK: memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
-// CHECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
-// CHECK: return %[[VAL_87]] : tensor<f64>
+// CHECK: return %[[VAL_0]] : tensor<f64>
 // CHECK: }
 func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
                              %arga: tensor<64x32xf64, #SparseMatrix>,