[mlir] Optimize the implementation of RegionDCE

The current implementation has some inefficiencies that become noticeable when running on large modules. This revision optimizes the code, and updates some outdated idioms with newer utilities. The main components of this optimization include:

* Add an overload of Block::eraseArguments that allows for O(N) erasure of disjoint arguments.
* Don't process entry block arguments given that we don't erase them at this point.
* Don't track individual operation results, given that we don't erase them. We can just track the parent operation.

Differential Revision: https://reviews.llvm.org/D98309
This commit is contained in:
River Riddle 2021-03-10 16:13:25 -08:00
parent 70af0bf6fe
commit 4e02eb8014
3 changed files with 69 additions and 50 deletions

View File

@ -108,7 +108,10 @@ public:
void eraseArguments(ArrayRef<unsigned> argIndices);
/// Erases the arguments that have their corresponding bit set in
/// `eraseIndices` and removes them from the argument list.
void eraseArguments(llvm::BitVector eraseIndices);
void eraseArguments(const llvm::BitVector &eraseIndices);
/// Erases arguments using the given predicate. If the predicate returns true,
/// that argument is erased.
void eraseArguments(function_ref<bool(BlockArgument)> shouldEraseFn);
unsigned getNumArguments() { return arguments.size(); }
BlockArgument getArgument(unsigned i) { return arguments[i]; }

View File

@ -188,23 +188,32 @@ void Block::eraseArguments(ArrayRef<unsigned> argIndices) {
eraseArguments(eraseIndices);
}
void Block::eraseArguments(llvm::BitVector eraseIndices) {
// We do this in reverse so that we erase later indices before earlier
// indices, to avoid shifting the later indices.
unsigned originalNumArgs = getNumArguments();
int64_t firstErased = originalNumArgs;
for (unsigned i = 0; i < originalNumArgs; ++i) {
int64_t currentPos = originalNumArgs - i - 1;
if (eraseIndices.test(currentPos)) {
arguments[currentPos].destroy();
arguments.erase(arguments.begin() + currentPos);
firstErased = currentPos;
// Bit-vector overload: forwards to the predicate-based overload, erasing every
// argument whose argument number has its bit set in `eraseIndices`.
void Block::eraseArguments(const llvm::BitVector &eraseIndices) {
eraseArguments(
[&](BlockArgument arg) { return eraseIndices.test(arg.getArgNumber()); });
}
void Block::eraseArguments(function_ref<bool(BlockArgument)> shouldEraseFn) {
auto firstDead = llvm::find_if(arguments, shouldEraseFn);
if (firstDead == arguments.end())
return;
// Destroy the first dead argument up front; this avoids reapplying the
// predicate to it.
unsigned index = firstDead->getArgNumber();
firstDead->destroy();
// Iterate the remaining arguments to remove any that are now dead.
for (auto it = std::next(firstDead), e = arguments.end(); it != e; ++it) {
// Destroy dead arguments, and shift those that are still live.
if (shouldEraseFn(*it)) {
it->destroy();
} else {
it->setArgNumber(index++);
*firstDead++ = *it;
}
}
// Update the cached position for the arguments after the first erased one.
int64_t index = firstErased;
for (BlockArgument arg : llvm::drop_begin(arguments, index))
arg.setArgNumber(index++);
arguments.erase(firstDead, arguments.end());
}
//===----------------------------------------------------------------------===//

View File

@ -139,9 +139,23 @@ namespace {
class LiveMap {
public:
/// Value methods.
bool wasProvenLive(Value value) { return liveValues.count(value); }
/// Returns whether `value` has been proven live. An op result is answered by
/// querying the liveness of its owning operation (results are not tracked
/// individually); any other value is cast to a block argument and looked up
/// through the BlockArgument overload.
bool wasProvenLive(Value value) {
// TODO: For results that are removable, e.g. for region based control flow,
// we could allow for these values to be tracked independently.
if (OpResult result = value.dyn_cast<OpResult>())
return wasProvenLive(result.getOwner());
return wasProvenLive(value.cast<BlockArgument>());
}
/// Returns whether `arg` is present in the set of live values.
bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
void setProvedLive(Value value) {
changed |= liveValues.insert(value).second;
// TODO: For results that are removable, e.g. for region based control flow,
// we could allow for these values to be tracked independently.
if (OpResult result = value.dyn_cast<OpResult>())
return setProvedLive(result.getOwner());
setProvedLive(value.cast<BlockArgument>());
}
/// Records `arg` in the set of live values. `changed` is OR-ed with the
/// insertion result, so it becomes true when this call added a new entry.
void setProvedLive(BlockArgument arg) {
changed |= liveValues.insert(arg).second;
}
/// Operation methods.
@ -192,15 +206,6 @@ static void processValue(Value value, LiveMap &liveMap) {
liveMap.setProvedLive(value);
}
/// Returns true if `op` must be treated as live regardless of whether its
/// results are used: it might be a terminator, it is not known to be free of
/// memory effects, or it holds at least one region.
static bool isOpIntrinsicallyLive(Operation *op) {
// This pass doesn't modify the CFG, so terminators are never deleted.
if (op->mightHaveTrait<OpTrait::IsTerminator>())
return true;
// If the op has a side effect, we treat it as live.
// TODO: Properly handle region side effects.
// Until then, any op with regions is conservatively reported as live.
return !MemoryEffectOpInterface::hasNoEffect(op) || op->getNumRegions() != 0;
}
static void propagateLiveness(Region &region, LiveMap &liveMap);
static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
@ -226,9 +231,6 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
}
static void propagateLiveness(Operation *op, LiveMap &liveMap) {
// All Value's are either a block argument or an op result.
// We call processValue on those cases.
// Recurse on any regions the op has.
for (Region &region : op->getRegions())
propagateLiveness(region, liveMap);
@ -237,18 +239,17 @@ static void propagateLiveness(Operation *op, LiveMap &liveMap) {
if (op->hasTrait<OpTrait::IsTerminator>())
return propagateTerminatorLiveness(op, liveMap);
// Process the op itself.
if (isOpIntrinsicallyLive(op)) {
liveMap.setProvedLive(op);
// Don't reprocess live operations.
if (liveMap.wasProvenLive(op))
return;
}
// Process the op itself.
if (!wouldOpBeTriviallyDead(op))
return liveMap.setProvedLive(op);
// If the op isn't intrinsically alive, check its results.
for (Value value : op->getResults())
processValue(value, liveMap);
bool provedLive = llvm::any_of(op->getResults(), [&](Value value) {
return liveMap.wasProvenLive(value);
});
if (provedLive)
liveMap.setProvedLive(op);
}
static void propagateLiveness(Region &region, LiveMap &liveMap) {
@ -260,8 +261,18 @@ static void propagateLiveness(Region &region, LiveMap &liveMap) {
// faster convergence to a fixed point (we try to visit uses before defs).
for (Operation &op : llvm::reverse(block->getOperations()))
propagateLiveness(&op, liveMap);
for (Value value : block->getArguments())
processValue(value, liveMap);
// We currently do not remove entry block arguments, so there is no need to
// track their liveness.
// TODO: We could track these and enable removing dead operands/arguments
// from region control flow operations.
if (block->isEntryBlock())
continue;
for (Value value : block->getArguments()) {
if (!liveMap.wasProvenLive(value))
processValue(value, liveMap);
}
}
}
@ -314,11 +325,12 @@ static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
for (Operation &childOp :
llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
erasedAnything |=
succeeded(deleteDeadness(childOp.getRegions(), liveMap));
if (!liveMap.wasProvenLive(&childOp)) {
erasedAnything = true;
childOp.erase();
} else {
erasedAnything |=
succeeded(deleteDeadness(childOp.getRegions(), liveMap));
}
}
}
@ -326,13 +338,8 @@ static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
// The entry block has an unknown contract with its enclosing block, so
// skip it.
for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
// Iterate in reverse to avoid shifting later arguments when deleting
// earlier arguments.
for (unsigned i = 0, e = block.getNumArguments(); i < e; i++)
if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) {
block.eraseArgument(e - i - 1);
erasedAnything = true;
}
block.eraseArguments(
[&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
}
}
return success(erasedAnything);