llvm-project/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

198 lines
7.3 KiB
C++
Raw Normal View History

//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting invariant operations
// in the context of Linalg transformations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "linalg-hoisting"
#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
using namespace mlir;
using namespace mlir::linalg;
using llvm::dbgs;
/// Hoists `alloc` / `alloca` / `dealloc` operations out of their immediately
/// enclosing `scf.for` loop, iterating to a fixed point.
///
/// An `alloc`/`alloca` is moved right before the loop when:
///   * all of its operands are defined outside of the loop, and
///   * every use of its (single) result is either a `ViewLikeOpInterface` op
///     or a `dealloc` located directly in the same loop body.
/// A `dealloc` whose operands are all defined outside of the loop is moved
/// right after the loop. Running the walk repeatedly lets an alloc hoisted in
/// one step expose its dealloc (and outer loops) to later steps.
///
/// NOTE(review): after hoisting, all iterations share one buffer. This is
/// only sound if no iteration relies on receiving a fresh buffer (e.g. a view
/// escaping through the loop's yield) — TODO confirm callers guarantee this.
void mlir::linalg::hoistViewAllocOps(FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&changed](Operation *op) {
      if (!isa<AllocOp, AllocaOp, DeallocOp>(op))
        return;

      LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *op << "\n");
      auto loop = dyn_cast<scf::ForOp>(op->getParentOp());
      LLVM_DEBUG(DBGS() << "Parent op: " << *op->getParentOp() << "\n");

      // Only hoist out of immediately enclosing scf::ForOp.
      if (!loop)
        return;

      // If any operand is defined inside the loop don't hoist.
      if (llvm::any_of(op->getOperands(), [&](Value v) {
            return !loop.isDefinedOutsideOfLoop(v);
          }))
        return;

      LLVM_DEBUG(DBGS() << "All operands defined outside \n");

      // If alloc has other uses than ViewLikeOp and DeallocOp don't hoist.
      // A DeallocOp nested in a deeper region (e.g. under an scf.if inside
      // the loop body) is never moved out together with the alloc. It would
      // then free the shared buffer on some iteration while later iterations
      // still use it, so only deallocs directly in the loop body qualify.
      Value v;
      if (op->getNumResults() > 0) {
        assert(op->getNumResults() == 1 && "Unexpected multi-result alloc");
        v = op->getResult(0);
      }
      if (v && !llvm::all_of(v.getUses(), [&](OpOperand &operand) {
            Operation *user = operand.getOwner();
            if (isa<ViewLikeOpInterface>(user))
              return true;
            return isa<DeallocOp>(user) &&
                   user->getParentOp() == loop.getOperation();
          })) {
        LLVM_DEBUG(DBGS() << "Found non view-like or dealloc use: bail\n");
        return;
      }

      if (isa<AllocOp, AllocaOp>(op)) {
        // Move AllocOp before the loop. Do not report progress if the move
        // failed: the enclosing fixed-point loop would otherwise revisit the
        // same op forever.
        if (failed(loop.moveOutOfLoop({op})))
          return;
      } else {
        // Move DeallocOp outside of the loop.
        op->moveAfter(loop);
      }
      changed = true;
    });
  }
}
/// Hoists matching `vector.transfer_read` / `vector.transfer_write` pairs out
/// of their immediately enclosing `scf.for` loop, iterating to a fixed point.
///
/// For each transfer_read directly inside an `scf.for`, loop-invariant code is
/// first moved out of the loop. The read/write pair is hoisted only when:
///   * all operands of the read are defined outside of the loop,
///   * a transfer_write to the same source exists in the read's forward slice,
///   * read and write use identical indices, vector type and permutation map,
///   * the read properly dominates the write, and
///   * every other use of the source that the loop properly dominates is a
///     transfer_read/transfer_write statically disjoint from the write.
/// The read is then moved before the loop and the write after it. The loop is
/// rebuilt with `cloneWithNewYields`, which receives the read's vector and the
/// written vector (presumably as a new iter operand / yielded value pair).
/// The old loop is then erased, and the walk restarts.
void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;

    func.walk([&](vector::TransferReadOp transferRead) {
      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                        << *transferRead.getOperation() << "\n");
      auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
      LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
                        << "\n");
      if (!loop)
        return WalkResult::advance();

      // Hoisting invariant code first may make the read's operands (e.g. its
      // indices) become defined outside of the loop.
      if (failed(moveLoopInvariantCode(
              cast<LoopLikeOpInterface>(loop.getOperation()))))
        llvm_unreachable(
            "Unexpected failure to move invariant code out of loop");

      LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
                        << "\n");

      // All operands of the TransferRead must be defined outside of the loop.
      // Checked before computing the forward slice, which is more expensive.
      for (auto operand : transferRead.getOperands())
        if (!loop.isDefinedOutsideOfLoop(operand))
          return WalkResult::advance();

      llvm::SetVector<Operation *> forwardSlice;
      getForwardSlice(transferRead, &forwardSlice);

      // Find a TransferWriteOp in the forward slice of `transferRead` that
      // writes to the same source. The slice is traversed in reverse and each
      // match overwrites `transferWrite`, so the selected write is the match
      // that appears earliest in `forwardSlice` order.
      // NOTE(review): if the last matching write in slice order is intended,
      // iterate forward instead — confirm.
      vector::TransferWriteOp transferWrite;
      for (auto *sliceOp : llvm::reverse(forwardSlice)) {
        auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
        if (!candidateWrite || candidateWrite.source() != transferRead.source())
          continue;
        transferWrite = candidateWrite;
      }

      // Only hoist transfer_read / transfer_write pairs for now.
      if (!transferWrite)
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
                        << "\n");

      // Approximate aliasing by checking that:
      //   1. the read and write are matching: identical indices, vector type
      //      and permutation map,
      //   2. no other operations in the loop access the same memref except
      //      for transfer_read/transfer_write accessing statically disjoint
      //      slices.
      // Any mismatch in (1) must bail. Otherwise the value written in one
      // iteration would be forwarded as the value read by the next one even
      // though they access different elements. It would also pair
      // differently-typed vectors in the rebuilt loop.
      if (transferRead.indices() != transferWrite.indices() ||
          transferRead.getVectorType() != transferWrite.getVectorType() ||
          transferRead.permutation_map() != transferWrite.permutation_map())
        return WalkResult::advance();

      // TODO: may want to memoize this information for performance but it
      // likely gets invalidated often.
      DominanceInfo dom(loop);
      if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
        return WalkResult::advance();
      // NOTE(review): only direct uses of the source value are inspected.
      // Accesses through other views of the same underlying buffer are not
      // visible here — confirm callers rule such aliasing out.
      for (auto &use : transferRead.source().getUses()) {
        if (!dom.properlyDominates(loop, use.getOwner()))
          continue;
        if (use.getOwner() == transferRead.getOperation() ||
            use.getOwner() == transferWrite.getOperation())
          continue;
        if (auto transferWriteUse =
                dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
          if (!isDisjointTransferSet(
                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                  cast<VectorTransferOpInterface>(
                      transferWriteUse.getOperation())))
            return WalkResult::advance();
        } else if (auto transferReadUse =
                       dyn_cast<vector::TransferReadOp>(use.getOwner())) {
          if (!isDisjointTransferSet(
                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                  cast<VectorTransferOpInterface>(
                      transferReadUse.getOperation())))
            return WalkResult::advance();
        } else {
          // Unknown use, we cannot prove that it doesn't alias with the
          // transferRead/transferWrite operations.
          return WalkResult::advance();
        }
      }

      // Hoist read before.
      if (failed(loop.moveOutOfLoop({transferRead})))
        llvm_unreachable(
            "Unexpected failure to move transfer read out of loop");

      // Hoist write after.
      transferWrite->moveAfter(loop);

      // Rewrite `loop` with new yields by cloning and erase the original loop.
      OpBuilder b(transferRead);
      auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
                                         transferWrite.vector());

      // Transfer write has been hoisted, need to update the written value to
      // the value yielded by the newForOp.
      transferWrite.vector().replaceAllUsesWith(
          newForOp.getResults().take_back()[0]);

      changed = true;
      loop.erase();
      // Need to interrupt and restart because erasing the loop messes up the
      // walk.
      return WalkResult::interrupt();
    });
  }
}