[mlir][bufferize] Improve to_tensor/to_memref folding
Differential Revision: https://reviews.llvm.org/D128615
parent 11b414463d
commit cb47124179
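In short: foldToMemrefToTensorPair now always folds a to_memref(to_tensor(x)) pair (the allowSameType flag is dropped), the ToMemrefOp pattern formerly named TensorLoadToMemref becomes ToMemrefToTensorFolding and simply defers to that helper, and a new ToTensorToMemrefFolding pattern folds the opposite to_tensor(to_memref(t)) pairs. As a rough sketch of the first direction (the value names and memref types below are invented for illustration, not taken from the patch), the pair collapses to the original memref, with a memref.cast inserted when the two memref types differ:

    %t  = bufferization.to_tensor %m : memref<?xf32>
    %m2 = bufferization.to_memref %t : memref<4xf32>
    // ... folds roughly to ...
    %m2 = memref.cast %m : memref<?xf32> to memref<4xf32>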
@@ -55,8 +55,7 @@ FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
 LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                       ToMemrefOp toMemref,
-                                       bool allowSameType = true);
+                                       ToMemrefOp toMemref);
 
 } // namespace bufferization
 } // namespace mlir
@@ -84,8 +84,9 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
 
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
-LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
-    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
+LogicalResult
+mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
+                                              ToMemrefOp toMemref) {
   auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
   if (!memrefToTensor)
     return failure();
@@ -95,9 +96,6 @@ LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
 
   // Directly rewrite if the type did not change.
   if (srcType == destType) {
-    // Function can be configured to only handle cases where a cast is needed.
-    if (!allowSameType)
-      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
     return success();
   }
@@ -541,6 +539,19 @@ OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
 }
 
 namespace {
-
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
+struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
+  using OpRewritePattern<ToTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
+                                PatternRewriter &rewriter) const final {
+    auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
+    if (!toMemrefOp)
+      return failure();
+    rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
+    return success();
+  }
+};
+
 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
@@ -556,12 +567,11 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
     return success();
   }
 };
 
 } // namespace
 
 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfToTensorFolder>(context);
+  results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -601,17 +611,14 @@ struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
   }
 };
 
-/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
-/// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in.
-struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
+/// cast if necessary.
+struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                 PatternRewriter &rewriter) const final {
-    // Only handle cases where a cast is needed. The other case is handled by
-    // the folder.
-    return foldToMemrefToTensorPair(rewriter, toMemref,
-                                    /*allowSameType=*/false);
+    return foldToMemrefToTensorPair(rewriter, toMemref);
   }
 };
@@ -651,8 +658,8 @@ struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
 
 void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
-      context);
+  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
+              ToMemrefToTensorFolding>(context);
 }
 
 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
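The test updates below follow from the new ToTensorToMemrefFolding pattern: a bufferization.to_tensor fed by a bufferization.to_memref of a tensor now folds back to that original tensor, so the intermediate to_tensor results disappear from the CHECK lines. A minimal sketch of that rewrite (names and types again made up for illustration):

    %m = bufferization.to_memref %t : memref<128x128xf32>
    %r = bufferization.to_tensor %m : memref<128x128xf32>
    // uses of %r are redirected to %t; the pair is then dead code

In the last_value test this is why %[[T1]] is now returned directly where %[[R1]] used to be materialized via to_tensor.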
@@ -787,8 +787,7 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
   }
 
 // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
-// CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
-// CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+// CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
   return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
 }
 
@@ -109,8 +109,7 @@
 // CHECK: scf.yield %[[VAL_84]] : f64
 // CHECK: }
 // CHECK: memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
-// CHECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
-// CHECK: return %[[VAL_87]] : tensor<f64>
+// CHECK: return %[[VAL_0]] : tensor<f64>
 // CHECK: }
 func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
                              %arga: tensor<64x32xf64, #SparseMatrix>,