[mlir][bufferize] Improve to_tensor/to_memref folding

Differential Revision: https://reviews.llvm.org/D128615
Matthias Springer, 2022-06-27 21:34:09 +02:00
commit cb47124179
parent 11b414463d
4 changed files with 26 additions and 22 deletions
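The main addition is a canonicalization pattern (ToTensorToMemrefFolding) that folds bufferization.to_tensor(bufferization.to_memref(x)) back to x; the existing to_memref(to_tensor(x)) fold helper is simplified by dropping its allowSameType flag. A minimal sketch of the new rewrite, with value names and types invented for illustration:

    // Before canonicalization (%t is a tensor-typed block argument):
    %m = bufferization.to_memref %t : memref<128x128xf32>
    %u = bufferization.to_tensor %m : memref<128x128xf32>
    return %u : tensor<128x128xf32>

    // After: %u is replaced by %t and the round-trip casts become dead.
    return %t : tensor<128x128xf32>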


@@ -55,8 +55,7 @@ FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
 LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                       ToMemrefOp toMemref,
-                                       bool allowSameType = true);
+                                       ToMemrefOp toMemref);
 
 } // namespace bufferization
 } // namespace mlir

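As the declaration above documents, to_memref(to_tensor(x)) folds directly to x when x's type matches the result type of the to_memref op; otherwise a memref.cast has to be materialized. A hedged sketch of the cast case (the concrete types are chosen only for illustration and assume a cast-compatible pair):

    // Before folding: the to_memref result type (unranked) differs from the
    // type of the underlying memref %src.
    %t = bufferization.to_tensor %src : memref<4xf32>
    %m = bufferization.to_memref %t : memref<*xf32>

    // After foldToMemrefToTensorPair: the pair is replaced by a cast of %src.
    %m = memref.cast %src : memref<4xf32> to memref<*xf32>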

@@ -84,8 +84,9 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
-LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
-    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
+LogicalResult
+mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
+                                              ToMemrefOp toMemref) {
   auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
   if (!memrefToTensor)
     return failure();
@@ -95,9 +96,6 @@ LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
   // Directly rewrite if the type did not change.
   if (srcType == destType) {
-    // Function can be configured to only handle cases where a cast is needed.
-    if (!allowSameType)
-      return failure();
     rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
     return success();
   }
@@ -541,6 +539,19 @@ OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
 }
 
 namespace {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
+struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
+  using OpRewritePattern<ToTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
+                                PatternRewriter &rewriter) const final {
+    auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
+    if (!toMemrefOp)
+      return failure();
+    rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
+    return success();
+  }
+};
 
 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
@@ -556,12 +567,11 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
     return success();
   }
 };
 
 } // namespace
 
 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfToTensorFolder>(context);
+  results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -601,17 +611,14 @@ struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
   }
 };
 
-/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
-/// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in.
-struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
+/// cast if necessary.
+struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                 PatternRewriter &rewriter) const final {
-    // Only handle cases where a cast is needed. The other case is handled by
-    // the folder.
-    return foldToMemrefToTensorPair(rewriter, toMemref,
-                                    /*allowSameType=*/false);
+    return foldToMemrefToTensorPair(rewriter, toMemref);
   }
 };
 
@@ -651,8 +658,8 @@ struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
 
 void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
-      context);
+  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
+              ToMemrefToTensorFolding>(context);
 }
 
 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,


@@ -787,8 +787,7 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
   }
 
 // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
-// CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
-// CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+// CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
   return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
 }
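The updated CHECK lines reflect the new canonicalization: %[[M1]] is presumably produced earlier in this test by a to_memref of %[[T1]], so the trailing to_tensor of %[[M1]] now folds back to %[[T1]] and the bufferized function returns it directly. Roughly (a sketch based on the captured names, not the verbatim test output):

    // Earlier in the bufferized output (per the %[[M1]] capture):
    %m1 = bufferization.to_memref %t1 : memref<128x128xf32>
    // The former "%r1 = bufferization.to_tensor %m1" is folded away, so the
    // return now uses %t1 directly instead of %r1.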


@@ -109,8 +109,7 @@
 // CHECK:           scf.yield %[[VAL_84]] : f64
 // CHECK:         }
 // CHECK:         memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
-// CHECK:         %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
-// CHECK:         return %[[VAL_87]] : tensor<f64>
+// CHECK:         return %[[VAL_0]] : tensor<f64>
 // CHECK:       }
 func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
                             %arga: tensor<64x32xf64, #SparseMatrix>,