[mlir][bufferize] Improve to_tensor/to_memref folding
Differential Revision: https://reviews.llvm.org/D128615
commit cb47124179
parent 11b414463d
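For orientation, here is a minimal sketch of the IR affected by this change (illustrative only, not part of the commit; the @example function and its types are made up). A to_tensor op fed by a to_memref op now canonicalizes away, and the reverse to_memref(to_tensor(x)) pair still folds to x, with a memref.cast inserted only when the source and result memref types differ.

// Illustrative sketch; @example and its types are hypothetical.
func.func @example(%t: tensor<4xf32>) -> tensor<4xf32> {
  %m = bufferization.to_memref %t : memref<4xf32>
  // Canonicalization replaces %u with %t and erases the dead to_memref/to_tensor pair.
  %u = bufferization.to_tensor %m : memref<4xf32>
  return %u : tensor<4xf32>
}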
@@ -55,8 +55,7 @@ FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
 LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                       ToMemrefOp toMemref,
-                                       bool allowSameType = true);
+                                       ToMemrefOp toMemref);
 
 } // namespace bufferization
 } // namespace mlir
@@ -84,8 +84,9 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
 
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
-LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
-    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
+LogicalResult
+mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
+                                              ToMemrefOp toMemref) {
   auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
   if (!memrefToTensor)
     return failure();
@@ -95,9 +96,6 @@ LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
 
   // Directly rewrite if the type did not change.
   if (srcType == destType) {
-    // Function can be configured to only handle cases where a cast is needed.
-    if (!allowSameType)
-      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
     return success();
   }
@@ -541,6 +539,19 @@ OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
 }
 
 namespace {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
+struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
+  using OpRewritePattern<ToTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
+                                PatternRewriter &rewriter) const final {
+    auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
+    if (!toMemrefOp)
+      return failure();
+    rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
+    return success();
+  }
+};
 
 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
@@ -556,12 +567,11 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
     return success();
   }
 };
 
 } // namespace
 
 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfToTensorFolder>(context);
+  results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -601,17 +611,14 @@ struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
   }
 };
 
-/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
-/// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in.
-struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
+/// cast if necessary.
+struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                 PatternRewriter &rewriter) const final {
-    // Only handle cases where a cast is needed. The other case is handled by
-    // the folder.
-    return foldToMemrefToTensorPair(rewriter, toMemref,
-                                    /*allowSameType=*/false);
+    return foldToMemrefToTensorPair(rewriter, toMemref);
   }
 };
 
@@ -651,8 +658,8 @@ struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
 
 void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
-      context);
+  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
+              ToMemrefToTensorFolding>(context);
 }
 
 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
@@ -787,8 +787,7 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
 }
 
 // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
-// CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
-// CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+// CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
   return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
 }
 
@@ -109,8 +109,7 @@
 // CHECK: scf.yield %[[VAL_84]] : f64
 // CHECK: }
 // CHECK: memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
-// CHECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
-// CHECK: return %[[VAL_87]] : tensor<f64>
+// CHECK: return %[[VAL_0]] : tensor<f64>
 // CHECK: }
 func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
                              %arga: tensor<64x32xf64, #SparseMatrix>,