forked from OSchip/llvm-project
[mlir][Linalg] Add a hoistRedundantVectorTransfers helper function
This revision adds a helper function that iteratively hoists vector.transfer_read / vector.transfer_write pairs out of their immediately enclosing scf::ForOp, if the following conditions are true: 1. The two ops access the same memref with the same indices. 2. All operands are invariant under the enclosing scf::ForOp. 3. No uses of the memref either dominate the transfer_read or are dominated by the transfer_write (i.e. no aliasing between the write and the read across the loop). To improve hoisting opportunities, the `moveLoopInvariantCode` helper function is called on the candidate loop above which to hoist. Hoisting the transfers results in the scf::ForOp yielding the value that originally transited through memory. This revision additionally exposes `moveLoopInvariantCode` as a helper in LoopUtils.h and updates SliceAnalysis to support scf::ForOp return values and to allow hoisting across multiple scf::ForOps. Differential Revision: https://reviews.llvm.org/D81199
This commit is contained in:
parent
9bfdf11807
commit
6953cf6502
|
@ -16,11 +16,25 @@ namespace linalg {
|
|||
|
||||
/// Hoist alloc/dealloc pairs and alloca op out of immediately enclosing
|
||||
/// scf::ForOp if both conditions are true:
|
||||
/// 1. all operands are defined outside the loop.
|
||||
/// 2. all uses are ViewLikeOp or DeallocOp.
|
||||
/// 1. All operands are defined outside the loop.
|
||||
/// 2. All uses are ViewLikeOp or DeallocOp.
|
||||
// TODO: generalize on a per-need basis.
|
||||
void hoistViewAllocOps(FuncOp func);
|
||||
|
||||
/// Hoist vector.transfer_read/vector.transfer_write pairs out of immediately
|
||||
/// enclosing scf::ForOp iteratively, if the following conditions are true:
|
||||
/// 1. The two ops access the same memref with the same indices.
|
||||
/// 2. All operands are invariant under the enclosing scf::ForOp.
|
||||
/// 3. No uses of the memref either dominate the transfer_read or are
|
||||
/// dominated by the transfer_write (i.e. no aliasing between the write and
|
||||
/// the read across the loop).
|
||||
/// To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
|
||||
/// function on the candidate loop above which to hoist. Hoisting the transfers
|
||||
/// results in scf::ForOp yielding the value that originally transited through
|
||||
/// memory.
|
||||
// TODO: generalize on a per-need basis.
|
||||
void hoistRedundantVectorTransfers(FuncOp func);
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -22,10 +22,11 @@
|
|||
namespace mlir {
|
||||
class AffineForOp;
|
||||
class FuncOp;
|
||||
class LoopLikeOpInterface;
|
||||
struct MemRefRegion;
|
||||
class OpBuilder;
|
||||
class Value;
|
||||
class ValueRange;
|
||||
struct MemRefRegion;
|
||||
|
||||
namespace scf {
|
||||
class ForOp;
|
||||
|
@ -294,6 +295,9 @@ LogicalResult
|
|||
separateFullTiles(MutableArrayRef<AffineForOp> nest,
|
||||
SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
|
||||
|
||||
/// Move loop invariant code out of `looplike`.
|
||||
LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_LOOP_UTILS_H
|
||||
|
|
|
@ -41,20 +41,24 @@ static void getForwardSliceImpl(Operation *op,
|
|||
}
|
||||
|
||||
if (auto forOp = dyn_cast<AffineForOp>(op)) {
|
||||
for (auto *ownerInst : forOp.getInductionVar().getUsers())
|
||||
if (forwardSlice->count(ownerInst) == 0)
|
||||
getForwardSliceImpl(ownerInst, forwardSlice, filter);
|
||||
for (auto *ownerOp : forOp.getInductionVar().getUsers())
|
||||
if (forwardSlice->count(ownerOp) == 0)
|
||||
getForwardSliceImpl(ownerOp, forwardSlice, filter);
|
||||
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
|
||||
for (auto *ownerInst : forOp.getInductionVar().getUsers())
|
||||
if (forwardSlice->count(ownerInst) == 0)
|
||||
getForwardSliceImpl(ownerInst, forwardSlice, filter);
|
||||
for (auto *ownerOp : forOp.getInductionVar().getUsers())
|
||||
if (forwardSlice->count(ownerOp) == 0)
|
||||
getForwardSliceImpl(ownerOp, forwardSlice, filter);
|
||||
for (auto result : forOp.getResults())
|
||||
for (auto *ownerOp : result.getUsers())
|
||||
if (forwardSlice->count(ownerOp) == 0)
|
||||
getForwardSliceImpl(ownerOp, forwardSlice, filter);
|
||||
} else {
|
||||
assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
|
||||
assert(op->getNumResults() <= 1 && "unexpected multiple results");
|
||||
if (op->getNumResults() > 0) {
|
||||
for (auto *ownerInst : op->getResult(0).getUsers())
|
||||
if (forwardSlice->count(ownerInst) == 0)
|
||||
getForwardSliceImpl(ownerInst, forwardSlice, filter);
|
||||
for (auto *ownerOp : op->getResult(0).getUsers())
|
||||
if (forwardSlice->count(ownerOp) == 0)
|
||||
getForwardSliceImpl(ownerOp, forwardSlice, filter);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -139,15 +143,15 @@ SetVector<Operation *> mlir::getSlice(Operation *op,
|
|||
SetVector<Operation *> backwardSlice;
|
||||
SetVector<Operation *> forwardSlice;
|
||||
while (currentIndex != slice.size()) {
|
||||
auto *currentInst = (slice)[currentIndex];
|
||||
// Compute and insert the backwardSlice starting from currentInst.
|
||||
auto *currentOp = (slice)[currentIndex];
|
||||
// Compute and insert the backwardSlice starting from currentOp.
|
||||
backwardSlice.clear();
|
||||
getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
|
||||
getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
|
||||
slice.insert(backwardSlice.begin(), backwardSlice.end());
|
||||
|
||||
// Compute and insert the forwardSlice starting from currentInst.
|
||||
// Compute and insert the forwardSlice starting from currentOp.
|
||||
forwardSlice.clear();
|
||||
getForwardSlice(currentInst, &forwardSlice, forwardFilter);
|
||||
getForwardSlice(currentOp, &forwardSlice, forwardFilter);
|
||||
slice.insert(forwardSlice.begin(), forwardSlice.end());
|
||||
++currentIndex;
|
||||
}
|
||||
|
|
|
@ -12,10 +12,15 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Utils.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/Dominance.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
|
@ -75,3 +80,96 @@ void mlir::linalg::hoistViewAllocOps(FuncOp func) {
|
|||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Iteratively hoists vector.transfer_read / vector.transfer_write pairs out
/// of their immediately enclosing scf::ForOp inside `func`.
///
/// A pair is hoisted only when:
///   1. both ops access the same memref with the same indices,
///   2. every operand of the read is defined outside of the loop, and
///   3. no other use of the memref properly dominates the read or is properly
///      dominated by the write.
///
/// Each successful hoist replaces the loop by a clone with new yields and
/// erases the original, so the walk is interrupted and restarted; the
/// function returns once a full walk finishes without hoisting anything.
void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
  bool hoistedPair;
  do {
    hoistedPair = false;

    func.walk([&](vector::TransferReadOp transferRead) {
      Operation *readOp = transferRead.getOperation();
      LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *readOp << "\n");

      // Only reads nested directly inside an scf::ForOp are candidates.
      Operation *parentOp = transferRead.getParentOp();
      auto enclosingFor = dyn_cast<scf::ForOp>(parentOp);
      LLVM_DEBUG(DBGS() << "Parent op: " << *parentOp << "\n");
      if (!enclosingFor)
        return WalkResult::advance();

      // Run loop-invariant code motion on the loop first to improve the
      // hoisting opportunities.
      if (failed(moveLoopInvariantCode(
              cast<LoopLikeOpInterface>(enclosingFor.getOperation()))))
        llvm_unreachable(
            "Unexpected failure to move invariant code out of loop");

      LLVM_DEBUG(DBGS() << "Candidate read: " << *readOp << "\n");

      llvm::SetVector<Operation *> forwardSlice;
      getForwardSlice(transferRead, &forwardSlice);

      // Select the TransferWriteOp on the same memref as `transferRead` that
      // comes first in the iteration order of `forwardSlice`.
      vector::TransferWriteOp transferWrite;
      for (Operation *sliceOp : forwardSlice) {
        auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
        if (candidateWrite &&
            candidateWrite.memref() == transferRead.memref()) {
          transferWrite = candidateWrite;
          break;
        }
      }

      // Give up if any operand of the read is defined inside the loop.
      if (llvm::any_of(transferRead.getOperands(), [&](Value operand) {
            return !enclosingFor.isDefinedOutsideOfLoop(operand);
          }))
        return WalkResult::advance();

      // Only transfer_read / transfer_write pairs are hoisted for now.
      if (!transferWrite)
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
                        << "\n");

      // Aliasing is approximated by requiring that:
      //   1. the read and the write use the same indices,
      //   2. no other use of the memref properly dominates the read or is
      //      properly dominated by the write (such a use could alias between
      //      the write and the read across the loop).
      if (transferRead.indices() != transferWrite.indices())
        return WalkResult::advance();

      // TODO: may want to memoize this information for performance but it
      // likely gets invalidated often.
      DominanceInfo domInfo(enclosingFor);
      Operation *writeOp = transferWrite.getOperation();
      if (!domInfo.properlyDominates(readOp, writeOp))
        return WalkResult::advance();
      for (auto &memrefUse : transferRead.memref().getUses()) {
        Operation *user = memrefUse.getOwner();
        if (domInfo.properlyDominates(user, readOp) ||
            domInfo.properlyDominates(writeOp, user))
          return WalkResult::advance();
      }

      // The read goes in front of the loop...
      if (failed(enclosingFor.moveOutOfLoop({transferRead})))
        llvm_unreachable(
            "Unexpected failure to move transfer read out of loop");

      // ... and the write right behind it.
      transferWrite.getOperation()->moveAfter(enclosingFor);

      // Clone the loop with new yields; the original loop is erased below.
      OpBuilder b(transferRead);
      auto newForOp = cloneWithNewYields(b, enclosingFor,
                                         transferRead.vector(),
                                         transferWrite.vector());

      // The hoisted write must now store the value yielded by `newForOp`
      // (its last result) instead of the value computed in the old loop.
      transferWrite.vector().replaceAllUsesWith(
          newForOp.getResults().take_back()[0]);

      hoistedPair = true;
      enclosingFor.erase();
      // Erasing the loop invalidates the ongoing walk: interrupt it and let
      // the outer loop start a fresh one.
      return WalkResult::interrupt();
    });
  } while (hoistedPair);
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
@ -73,7 +74,7 @@ static bool canBeHoisted(Operation *op,
|
|||
return true;
|
||||
}
|
||||
|
||||
static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike) {
|
||||
LogicalResult mlir::moveLoopInvariantCode(LoopLikeOpInterface looplike) {
|
||||
auto &loopBody = looplike.getLoopBody();
|
||||
|
||||
// We use two collections here as we need to preserve the order for insertion
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-view-allocs | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-view-allocs -allow-unregistered-dialect | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-redundant-transfers -allow-unregistered-dialect | FileCheck %s --check-prefix=VECTOR_TRANSFERS
|
||||
|
||||
// CHECK-LABEL: func @hoist(
|
||||
// CHECK-LABEL: func @hoist_allocs(
|
||||
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
|
||||
// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
|
||||
// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
|
||||
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
|
||||
// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
|
||||
func @hoist(%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
|
||||
func @hoist_allocs(%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
|
||||
// CHECK-DAG: alloca(%[[VAL]]) : memref<?xi8>
|
||||
// CHECK-DAG: %[[A0:.*]] = alloc(%[[VAL]]) : memref<?xi8>
|
||||
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
|
||||
|
@ -80,3 +81,69 @@ func @hoist(%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
|
|||
// CHECK: dealloc %[[A0]] : memref<?xi8>
|
||||
return
|
||||
}
|
||||
|
||||
// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs(
|
||||
// VECTOR_TRANSFERS-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
|
||||
// VECTOR_TRANSFERS-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
|
||||
// VECTOR_TRANSFERS-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
|
||||
// VECTOR_TRANSFERS-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
|
||||
// VECTOR_TRANSFERS-SAME: %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
|
||||
// VECTOR_TRANSFERS-SAME: %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
|
||||
// VECTOR_TRANSFERS-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
|
||||
// VECTOR_TRANSFERS-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
|
||||
// VECTOR_TRANSFERS-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
|
||||
// VECTOR_TRANSFERS-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
|
||||
// VECTOR_TRANSFERS-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
|
||||
func @hoist_vector_transfer_pairs(
|
||||
%memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
|
||||
%memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
|
||||
%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
|
||||
%c0 = constant 0 : index
|
||||
%cst = constant 0.0 : f32
|
||||
|
||||
// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
|
||||
// VECTOR_TRANSFERS: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
|
||||
// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
|
||||
// VECTOR_TRANSFERS: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
|
||||
// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
|
||||
// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
|
||||
// VECTOR_TRANSFERS: "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
|
||||
// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32>
|
||||
// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
|
||||
// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
|
||||
// VECTOR_TRANSFERS: "some_use"(%[[MEMREF2]]) : (memref<?x?xf32>) -> vector<3xf32>
|
||||
// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
|
||||
// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
|
||||
// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32>
|
||||
// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<4xf32>, memref<?x?xf32>
|
||||
// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<5xf32>, memref<?x?xf32>
|
||||
// VECTOR_TRANSFERS: "some_crippling_use"(%[[MEMREF3]]) : (memref<?x?xf32>) -> ()
|
||||
// VECTOR_TRANSFERS: scf.yield {{.*}} : vector<1xf32>, vector<2xf32>
|
||||
// VECTOR_TRANSFERS: }
|
||||
// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32>
|
||||
// VECTOR_TRANSFERS: scf.yield {{.*}} : vector<1xf32>
|
||||
// VECTOR_TRANSFERS: }
|
||||
// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
|
||||
scf.for %i = %lb to %ub step %step {
|
||||
scf.for %j = %lb to %ub step %step {
|
||||
%r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
|
||||
%r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
|
||||
%r2 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
|
||||
%r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref<?x?xf32>, vector<4xf32>
|
||||
"some_crippling_use"(%memref4) : (memref<?x?xf32>) -> ()
|
||||
%r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32>
|
||||
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
|
||||
%u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
|
||||
%u2 = "some_use"(%memref2) : (memref<?x?xf32>) -> vector<3xf32>
|
||||
%u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
|
||||
%u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
|
||||
vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
|
||||
vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
|
||||
vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
|
||||
vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
|
||||
vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32>
|
||||
"some_crippling_use"(%memref3) : (memref<?x?xf32>) -> ()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -29,6 +29,10 @@ struct TestLinalgHoisting
|
|||
*this, "test-hoist-view-allocs",
|
||||
llvm::cl::desc("Test hoisting alloc used by view"),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> testHoistRedundantTransfers{
|
||||
*this, "test-hoist-redundant-transfers",
|
||||
llvm::cl::desc("Test hoisting transfer_read/transfer_write pairs"),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -37,6 +41,10 @@ void TestLinalgHoisting::runOnFunction() {
|
|||
hoistViewAllocOps(getFunction());
|
||||
return;
|
||||
}
|
||||
if (testHoistRedundantTransfers) {
|
||||
hoistRedundantVectorTransfers(getFunction());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
|
Loading…
Reference in New Issue