From 9a0af63d05eeec8d333af147f3f1bda1efe63b30 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 7 Jul 2021 08:02:02 +0000 Subject: [PATCH] [mlir][Linalg] Proper handling of ForOp and TiledLoopOp The `bufferizesToMemoryRead` condition was too optimistic in the case of operands that map to a block argument. This is the case for ForOp and TiledLoopOp. For such ops, forward the call to all uses of the matching BBArg. Differential Revision: https://reviews.llvm.org/D105540 --- .../Transforms/ComprehensiveBufferize.cpp | 15 ++++- ...mprehensive-module-bufferize-analysis.mlir | 55 +++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp index 23a1fb612a6c..cf4ec5228dd5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -542,12 +542,21 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) { return false; // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its // matching bbArg may. - if (isa(opOperand.getOwner())) + if (auto forOp = dyn_cast(opOperand.getOwner())) { + for (OpOperand &use : + forOp.getRegionIterArgForOpOperand(opOperand).getUses()) + if (bufferizesToMemoryRead(use)) + return true; return false; + } // TiledLoop alone doesn't bufferize to a memory read, one of the uses of its // matching bbArg may. - if (isa(opOperand.getOwner())) + if (auto tiledLoopOp = dyn_cast(opOperand.getOwner())) { + for (OpOperand &use : tiledLoopOp.getTiedBlockArgument(opOperand).getUses()) + if (bufferizesToMemoryRead(use)) + return true; return false; + } // CallOpInterface alone doesn't bufferize to a memory read, one of the uses // of the matching bbArg may. It is the responsibility of the caller to // inspect bbArgs. 
In the absence of a BufferizationAliasInfo, we need to be @@ -1685,6 +1694,8 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp, b.create(forOp.getLoc(), operandBuffer, resultBuffer); } BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand); + aliasInfo.createAliasInfoEntry(resultBuffer); + aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer); map(bvm, bbArg, resultBuffer); map(bvm, opResult, resultBuffer); } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir index a580cbb36060..2dea6fde4f34 100644 --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -474,6 +474,61 @@ func @scf_for_with_tensor.insert_slice(%A : tensor, // ----- +func private @some_use(tensor) -> () + +// CHECK-LABEL: func @scf_for_deps +func @scf_for_deps(%A : tensor {linalg.inplaceable = true}, + %B : tensor {linalg.inplaceable = true}, + %lb : index, %ub : index, %step : index) + -> (tensor, tensor) +{ + // %r0 must be out of place because one use of %t in the subsequent production + // of %r1 is read. + // CHECK: scf.for + // CHECK-NEXT: scf.yield + // CHECK-NEXT: {__inplace_results_attr__ = ["false"]} + %r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor) { + scf.yield %t : tensor + } + + // %r1 bufferizes inplace fine. + // CHECK: scf.for + // CHECK-NEXT: call + // CHECK-NEXT: scf.yield + // CHECK-NEXT: {__inplace_results_attr__ = ["true"]} + %r1 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor) { + call @some_use(%t) : (tensor) -> () + scf.yield %t : tensor + } + + // %r2 must be out of place because one use of %t in the subsequent production + // of %r3 is read. 
+ // CHECK: linalg.tiled_loop + // CHECK-NEXT: linalg.yield + // CHECK-NEXT: {__inplace_results_attr__ = ["false"]} + %r2 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step) + ins() + outs(%t = %B: tensor) { + linalg.yield %t : tensor + } + + // %r3 bufferizes inplace fine. + // CHECK: linalg.tiled_loop + // CHECK-NEXT: call + // CHECK-NEXT: linalg.yield + // CHECK-NEXT: {__inplace_results_attr__ = ["true"]} + %r3 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step) + ins() + outs(%t = %B: tensor) { + call @some_use(%t) : (tensor) -> () + linalg.yield %t : tensor + } + + return %r1, %r3: tensor, tensor +} + +// ----- + //===----------------------------------------------------------------------===// // Cross function boundary cases. //===----------------------------------------------------------------------===//