[mlir][bufferize] Fix op filter

Bufferization has an optional filter to exclude certain ops from analysis+bufferization. There were a few remaining places in the codebase where the filter was not checked.

Differential Revision: https://reviews.llvm.org/D125356
Matthias Springer 2022-05-12 09:27:21 +02:00
parent 7b53a45e14
commit 2fe40c34ea
2 changed files with 33 additions and 5 deletions
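
The fix routes every interface lookup through `BufferizationOptions::dynCastBufferizableOp` instead of a bare `dyn_cast<BufferizableOpInterface>`. A minimal sketch of the difference, assuming the `isOpAllowed` helper of this revision's options API (the free function below is illustrative, not upstream code):

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

using namespace mlir;
using namespace mlir::bufferization;

// Illustrative helper (not upstream code): a bare dyn_cast only checks that
// the op implements BufferizableOpInterface. The options-based variant also
// consults the op filter, so filtered-out ops look "not bufferizable" and
// fall back to the conservative defaults in the analysis.
static BufferizableOpInterface
dynCastIfAllowed(const BufferizationOptions &options, Operation *op) {
  if (!options.isOpAllowed(op))
    return nullptr; // Excluded by the filter: skip analysis+bufferization.
  return dyn_cast<BufferizableOpInterface>(op);
}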

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

@@ -117,7 +117,7 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
 SmallVector<OpOperand *>
 AnalysisState::getAliasingOpOperand(OpResult result) const {
   if (Operation *op = result.getDefiningOp())
-    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
       return bufferizableOp.getAliasingOpOperand(result, *this);
   return {};
 }
@@ -127,7 +127,7 @@ AnalysisState::getAliasingOpOperand(OpResult result) const {
 SmallVector<OpResult>
 AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
   if (auto bufferizableOp =
-          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
+          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
     return bufferizableOp.getAliasingOpResult(opOperand, *this);
   return {};
 }
@@ -136,7 +136,7 @@ AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
 /// op is not bufferizable.
 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
   if (auto bufferizableOp =
-          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
+          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
     return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
   // Unknown op that returns a tensor. The inplace analysis does not support it.
@@ -148,7 +148,7 @@ bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
 /// `true` if the op is not bufferizable.
 bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
   if (auto bufferizableOp =
-          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
+          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
     return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
   // Unknown op that returns a tensor. The inplace analysis does not support it.
@@ -160,7 +160,7 @@ bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
 /// alias. Return false if the op is not bufferizable.
 bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
   if (auto bufferizableOp =
-          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
+          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
     return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
   // Unknown op that returns a tensor. The inplace analysis does not support it.

mlir/test/Dialect/Linalg/bufferize.mlir

@@ -189,3 +189,31 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
 // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
 // CHECK: return %[[OUT_TENSOR]]
 }
+
+// -----
+
+// This is a regression test. The linalg-bufferize pass should ignore all func
+// dialect ops.
+
+// CHECK-LABEL: func private @csum(tensor<6xi64>) -> tensor<6xi64>
+func.func private @csum(%arg0: tensor<6xi64>) -> tensor<6xi64>
+
+// CHECK: func public @main(%[[arg0:.*]]: tensor<2x3xi1>)
+// CHECK: %[[collapse:.*]] = tensor.collapse_shape %[[arg0]]
+// CHECK: %[[collapse_m:.*]] = bufferization.to_memref %[[collapse]]
+// CHECK: %[[alloc:.*]] = memref.alloc()
+// CHECK: linalg.generic {{.*}} ins(%[[collapse_m]] : memref<6xi1>) outs(%[[alloc]] : memref<6xi64>)
+// CHECK: %[[generic_t:.*]] = bufferization.to_tensor %[[alloc]]
+// CHECK: %[[call:.*]] = call @csum(%[[generic_t]])
+// CHECK: return %[[call]]
+func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x3xi1> into tensor<6xi1>
+  %1 = linalg.init_tensor [6] : tensor<6xi64>
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%0 : tensor<6xi1>) outs(%1 : tensor<6xi64>) {
+  ^bb0(%arg1: i1, %arg2: i64):
+    %4 = arith.extui %arg1 : i1 to i64
+    linalg.yield %4 : i64
+  } -> tensor<6xi64>
+  %3 = func.call @csum(%2) : (tensor<6xi64>) -> tensor<6xi64>
+  return %3 : tensor<6xi64>
+}
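
For context on the regression test: linalg-bufferize populates the op filter with an allow-list of dialects, so func-dialect ops such as the `func.call` above must be left operating on tensors. A hedged sketch of such a filter setup, assuming this era's `allowDialectInFilter` API (exact names have shifted across MLIR revisions):

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

// Sketch (assumption: allowDialectInFilter is the filter API at this
// revision): only ops from the listed dialects pass the filter. With the
// fix above, dynCastBufferizableOp returns null for func.call, so the call
// keeps its tensor operands instead of being rewritten.
static void configureLinalgBufferizeFilter(BufferizationOptions &options) {
  options.allowDialectInFilter<linalg::LinalgDialect, tensor::TensorDialect>();
}

Before this patch, only some call sites honored the filter; the analysis entry points patched above still cast directly to BufferizableOpInterface, so filtered-out ops could be analyzed and bufferized anyway.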