[mlir] Extended Liveness analysis to support nested regions.

The current Liveness analysis does not support operations with nested regions.
This causes issues when querying liveness information about blocks nested within
operations. Furthermore, the live-in and live-out sets are not computed properly
in these cases.

Differential Revision: https://reviews.llvm.org/D77714
This commit is contained in:
Marcel Koester 2020-04-08 10:31:18 +02:00
parent 62da6ecea2
commit c79227cabb
4 changed files with 278 additions and 107 deletions

View File

@ -110,6 +110,11 @@ public:
/// the operation with an offending use.
bool isIsolatedFromAbove(Optional<Location> noteLoc = llvm::None);
/// Returns 'block' if 'block' lies in this region, or otherwise finds the
/// ancestor of 'block' that lies in this region. Returns nullptr if the
/// latter fails.
Block *findAncestorBlockInRegion(Block &block);
/// Drop all operand uses from operations within this region, which is
/// an essential step in breaking cyclic dependences between references when
/// they are to be deleted.

View File

@ -31,51 +31,66 @@ struct BlockInfoBuilder {
/// Fills the block builder with initial liveness information.
BlockInfoBuilder(Block *block) : block(block) {
auto gatherOutValues = [&](Value value) {
// Check whether this value will be in the outValues set (its uses escape
// this block). Due to the SSA properties of the program, the uses must
// occur after the definition. Therefore, we do not have to check
// additional conditions to detect an escaping value.
for (Operation *useOp : value.getUsers()) {
Block *ownerBlock = useOp->getBlock();
// Find an owner block in the current region. Note that a value does not
// escape this block if it is used in a nested region.
ownerBlock = block->getParent()->findAncestorBlockInRegion(*ownerBlock);
assert(ownerBlock && "Use leaves the current parent region");
if (ownerBlock != block) {
outValues.insert(value);
break;
}
}
};
// Mark all block arguments (phis) as defined.
for (BlockArgument argument : block->getArguments())
for (BlockArgument argument : block->getArguments()) {
// Insert value into the set of defined values.
defValues.insert(argument);
// Check all result values and whether their uses
// are inside this block or not (see outValues).
for (Operation &operation : *block)
for (Value result : operation.getResults()) {
defValues.insert(result);
// Gather all out values of all arguments in the current block.
gatherOutValues(argument);
}
// Check whether this value will be in the outValues
// set (its uses escape this block). Due to the SSA
// properties of the program, the uses must occur after
// the definition. Therefore, we do not have to check
// additional conditions to detect an escaping value.
for (OpOperand &use : result.getUses())
if (use.getOwner()->getBlock() != block) {
outValues.insert(result);
break;
}
}
// Gather out values of all operations in the current block.
for (Operation &operation : *block)
for (Value result : operation.getResults())
gatherOutValues(result);
// Mark all nested operation results as defined.
block->walk([&](Operation *op) {
for (Value result : op->getResults())
defValues.insert(result);
});
// Check all operations for used operands.
for (Operation &operation : block->getOperations())
for (Value operand : operation.getOperands()) {
block->walk([&](Operation *op) {
for (Value operand : op->getOperands()) {
// If the operand is already defined in the scope of this
// block, we can skip the value in the use set.
if (!defValues.count(operand))
useValues.insert(operand);
}
});
}
/// Updates live-in information of the current block.
/// To do so it uses the default liveness-computation formula:
/// newIn = use union out \ def.
/// The methods returns true, if the set has changed (newIn != in),
/// false otherwise.
/// Updates live-in information of the current block. To do so it uses the
/// default liveness-computation formula: newIn = use union out \ def. The
/// methods returns true, if the set has changed (newIn != in), false
/// otherwise.
bool updateLiveIn() {
ValueSetT newIn = useValues;
llvm::set_union(newIn, outValues);
llvm::set_subtract(newIn, defValues);
// It is sufficient to check the set sizes (instead of their contents)
// since the live-in set can only grow monotonically during all update
// operations.
// It is sufficient to check the set sizes (instead of their contents) since
// the live-in set can only grow monotonically during all update operations.
if (newIn.size() == inValues.size())
return false;
@ -83,9 +98,9 @@ struct BlockInfoBuilder {
return true;
}
/// Updates live-out information of the current block.
/// It iterates over all successors and unifies their live-in
/// values with the current live-out values.
/// Updates live-out information of the current block. It iterates over all
/// successors and unifies their live-in values with the current live-out
/// values.
template <typename SourceT> void updateLiveOut(SourceT &source) {
for (Block *succ : block->getSuccessors()) {
BlockInfoBuilder &builder = source[succ];
@ -110,20 +125,32 @@ struct BlockInfoBuilder {
};
} // namespace
/// Walks all regions (including nested regions recursively) and invokes the
/// given function for every block.
template <typename FuncT>
static void walkRegions(MutableArrayRef<Region> regions, const FuncT &func) {
for (Region &region : regions)
for (Block &block : region) {
func(block);
// Traverse all nested regions.
for (Operation &operation : block)
walkRegions(operation.getRegions(), func);
}
}
/// Builds the internal liveness block mapping.
static void buildBlockMapping(MutableArrayRef<Region> regions,
DenseMap<Block *, BlockInfoBuilder> &builders) {
llvm::SetVector<Block *> toProcess;
// Initialize all block structures
for (Region &region : regions)
for (Block &block : region) {
BlockInfoBuilder &builder =
builders.try_emplace(&block, &block).first->second;
walkRegions(regions, [&](Block &block) {
BlockInfoBuilder &builder =
builders.try_emplace(&block, &block).first->second;
if (builder.updateLiveIn())
toProcess.insert(block.pred_begin(), block.pred_end());
}
if (builder.updateLiveIn())
toProcess.insert(block.pred_begin(), block.pred_end());
});
// Propagate the in and out-value sets (fixpoint iteration)
while (!toProcess.empty()) {
@ -143,8 +170,8 @@ static void buildBlockMapping(MutableArrayRef<Region> regions,
// Liveness
//===----------------------------------------------------------------------===//
/// Creates a new Liveness analysis that computes liveness
/// information for all associated regions.
/// Creates a new Liveness analysis that computes liveness information for all
/// associated regions.
Liveness::Liveness(Operation *op) : operation(op) { build(op->getRegions()); }
/// Initializes the internal mappings.
@ -229,8 +256,8 @@ const Liveness::ValueSetT &Liveness::getLiveOut(Block *block) const {
return getLiveness(block)->out();
}
/// Returns true if the given operation represent the last use of the
/// given value.
/// Returns true if the given operation represent the last use of the given
/// value.
bool Liveness::isLastUse(Value value, Operation *operation) const {
Block *block = operation->getBlock();
const LivenessBlockInfo *blockInfo = getLiveness(block);
@ -257,22 +284,21 @@ void Liveness::print(raw_ostream &os) const {
DenseMap<Block *, size_t> blockIds;
DenseMap<Operation *, size_t> operationIds;
DenseMap<Value, size_t> valueIds;
for (Region &region : operation->getRegions())
for (Block &block : region) {
blockIds.insert({&block, blockIds.size()});
for (BlockArgument argument : block.getArguments())
valueIds.insert({argument, valueIds.size()});
for (Operation &operation : block) {
operationIds.insert({&operation, operationIds.size()});
for (Value result : operation.getResults())
valueIds.insert({result, valueIds.size()});
}
walkRegions(operation->getRegions(), [&](Block &block) {
blockIds.insert({&block, blockIds.size()});
for (BlockArgument argument : block.getArguments())
valueIds.insert({argument, valueIds.size()});
for (Operation &operation : block) {
operationIds.insert({&operation, operationIds.size()});
for (Value result : operation.getResults())
valueIds.insert({result, valueIds.size()});
}
});
// Local printing helpers
auto printValueRef = [&](Value value) {
if (Operation *defOp = value.getDefiningOp())
os << "val_" << defOp->getName();
os << "val_" << valueIds[value];
else {
auto blockArg = value.cast<BlockArgument>();
os << "arg" << blockArg.getArgNumber() << "@"
@ -292,39 +318,38 @@ void Liveness::print(raw_ostream &os) const {
};
// Dump information about in and out values.
for (Region &region : operation->getRegions())
for (Block &block : region) {
os << "// - Block: " << blockIds[&block] << "\n";
auto liveness = getLiveness(&block);
os << "// --- LiveIn: ";
printValueRefs(liveness->inValues);
os << "\n// --- LiveOut: ";
printValueRefs(liveness->outValues);
os << "\n";
walkRegions(operation->getRegions(), [&](Block &block) {
os << "// - Block: " << blockIds[&block] << "\n";
auto liveness = getLiveness(&block);
os << "// --- LiveIn: ";
printValueRefs(liveness->inValues);
os << "\n// --- LiveOut: ";
printValueRefs(liveness->outValues);
os << "\n";
// Print liveness intervals.
os << "// --- BeginLiveness";
for (Operation &op : block) {
if (op.getNumResults() < 1)
continue;
os << "\n";
for (Value result : op.getResults()) {
os << "// ";
printValueRef(result);
os << ":";
auto liveOperations = resolveLiveness(result);
std::sort(liveOperations.begin(), liveOperations.end(),
[&](Operation *left, Operation *right) {
return operationIds[left] < operationIds[right];
});
for (Operation *operation : liveOperations) {
os << "\n// ";
operation->print(os);
}
// Print liveness intervals.
os << "// --- BeginLiveness";
for (Operation &op : block) {
if (op.getNumResults() < 1)
continue;
os << "\n";
for (Value result : op.getResults()) {
os << "// ";
printValueRef(result);
os << ":";
auto liveOperations = resolveLiveness(result);
std::sort(liveOperations.begin(), liveOperations.end(),
[&](Operation *left, Operation *right) {
return operationIds[left] < operationIds[right];
});
for (Operation *operation : liveOperations) {
os << "\n// ";
operation->print(os);
}
}
os << "\n// --- EndLiveness\n";
}
os << "\n// --- EndLiveness\n";
});
os << "// -------------------\n";
}
@ -342,8 +367,8 @@ bool LivenessBlockInfo::isLiveOut(Value value) const {
return outValues.count(value);
}
/// Gets the start operation for the given value
/// (must be referenced in this block).
/// Gets the start operation for the given value (must be referenced in this
/// block).
Operation *LivenessBlockInfo::getStartOperation(Value value) const {
Operation *definingOp = value.getDefiningOp();
// The given value is either live-in or is defined
@ -363,13 +388,13 @@ Operation *LivenessBlockInfo::getEndOperation(Value value,
// Resolve the last operation (must exist by definition).
Operation *endOperation = startOperation;
for (OpOperand &use : value.getUses()) {
Operation *useOperation = use.getOwner();
// Check whether the use is in our block and after
// the current end operation.
if (useOperation->getBlock() == block &&
endOperation->isBeforeInBlock(useOperation))
endOperation = useOperation;
for (Operation *useOp : value.getUsers()) {
// Find the associated operation in the current block (if any).
useOp = block->findAncestorOpInBlock(*useOp);
// Check whether the use is in our block and after the current end
// operation.
if (useOp && endOperation->isBeforeInBlock(useOp))
endOperation = useOp;
}
return endOperation;
}

View File

@ -108,6 +108,20 @@ void Region::cloneInto(Region *dest, Region::iterator destPos,
it->walk(remapOperands);
}
/// Returns 'block' if 'block' lies in this region, or otherwise finds the
/// ancestor of 'block' that lies in this region. Returns nullptr if the latter
/// fails.
Block *Region::findAncestorBlockInRegion(Block &block) {
auto currBlock = &block;
while (currBlock->getParent() != this) {
Operation *parentOp = currBlock->getParentOp();
if (!parentOp || !parentOp->getBlock())
return nullptr;
currBlock = parentOp->getBlock();
}
return currBlock;
}
void Region::dropAllReferences() {
for (Block &b : *this)
b.dropAllReferences();

View File

@ -25,7 +25,7 @@ func @func_simpleBranch(%arg0: i32, %arg1 : i32) -> i32 {
// CHECK-NEXT: LiveIn: arg0@0 arg1@0
// CHECK-NEXT: LiveOut:{{ *$}}
// CHECK-NEXT: BeginLiveness
// CHECK: val_std.addi
// CHECK: val_2
// CHECK-NEXT: %0 = addi
// CHECK-NEXT: return
// CHECK-NEXT: EndLiveness
@ -58,7 +58,7 @@ func @func_condBranch(%cond : i1, %arg1: i32, %arg2 : i32) -> i32 {
// CHECK-NEXT: LiveIn: arg1@0 arg2@0
// CHECK-NEXT: LiveOut:{{ *$}}
// CHECK-NEXT: BeginLiveness
// CHECK: val_std.addi
// CHECK: val_3
// CHECK-NEXT: %0 = addi
// CHECK-NEXT: return
// CHECK-NEXT: EndLiveness
@ -80,7 +80,7 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
// CHECK-NEXT: LiveIn: arg1@0
// CHECK-NEXT: LiveOut: arg1@0 arg0@1
// CHECK-NEXT: BeginLiveness
// CHECK-NEXT: val_std.cmpi
// CHECK-NEXT: val_5
// CHECK-NEXT: %2 = cmpi
// CHECK-NEXT: cond_br
// CHECK-NEXT: EndLiveness
@ -91,11 +91,11 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
// CHECK-NEXT: LiveIn: arg1@0 arg0@1
// CHECK-NEXT: LiveOut: arg1@0
// CHECK-NEXT: BeginLiveness
// CHECK-NEXT: val_std.constant
// CHECK-NEXT: val_7
// CHECK-NEXT: %c
// CHECK-NEXT: %4 = addi
// CHECK-NEXT: %5 = addi
// CHECK-NEXT: val_std.addi
// CHECK-NEXT: val_8
// CHECK-NEXT: %4 = addi
// CHECK-NEXT: %5 = addi
// CHECK-NEXT: br
@ -118,33 +118,33 @@ func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 {
func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
// CHECK: Block: 0
// CHECK-NEXT: LiveIn:{{ *$}}
// CHECK-NEXT: LiveOut: arg2@0 val_std.muli val_std.addi
// CHECK-NEXT: LiveOut: arg2@0 val_9 val_10
// CHECK-NEXT: BeginLiveness
// CHECK-NEXT: val_std.addi
// CHECK-NEXT: val_4
// CHECK-NEXT: %0 = addi
// CHECK-NEXT: %c
// CHECK-NEXT: %1 = addi
// CHECK-NEXT: %2 = addi
// CHECK-NEXT: %3 = muli
// CHECK-NEXT: val_std.constant
// CHECK-NEXT: val_5
// CHECK-NEXT: %c
// CHECK-NEXT: %1 = addi
// CHECK-NEXT: %2 = addi
// CHECK-NEXT: %3 = muli
// CHECK-NEXT: %4 = muli
// CHECK-NEXT: %5 = addi
// CHECK-NEXT: val_std.addi
// CHECK-NEXT: val_6
// CHECK-NEXT: %1 = addi
// CHECK-NEXT: %2 = addi
// CHECK-NEXT: %3 = muli
// CHECK-NEXT: val_std.addi
// CHECK-NEXT: val_7
// CHECK-NEXT %2 = addi
// CHECK-NEXT %3 = muli
// CHECK-NEXT %4 = muli
// CHECK: val_std.muli
// CHECK: val_8
// CHECK-NEXT: %3 = muli
// CHECK-NEXT: %4 = muli
// CHECK-NEXT: val_std.muli
// CHECK-NEXT: val_9
// CHECK-NEXT: %4 = muli
// CHECK-NEXT: %5 = addi
// CHECK-NEXT: cond_br
@ -152,7 +152,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
// CHECK-NEXT: %6 = muli
// CHECK-NEXT: %7 = muli
// CHECK-NEXT: %8 = addi
// CHECK-NEXT: val_std.addi
// CHECK-NEXT: val_10
// CHECK-NEXT: %5 = addi
// CHECK-NEXT: cond_br
// CHECK-NEXT: %7
@ -168,7 +168,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
^bb1:
// CHECK: Block: 1
// CHECK-NEXT: LiveIn: arg2@0 val_std.muli
// CHECK-NEXT: LiveIn: arg2@0 val_9
// CHECK-NEXT: LiveOut: arg2@0
%const4 = constant 4 : i32
%6 = muli %4, %const4 : i32
@ -176,7 +176,7 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
^bb2:
// CHECK: Block: 2
// CHECK-NEXT: LiveIn: arg2@0 val_std.muli val_std.addi
// CHECK-NEXT: LiveIn: arg2@0 val_9 val_10
// CHECK-NEXT: LiveOut: arg2@0
%7 = muli %4, %5 : i32
%8 = addi %4, %arg2 : i32
@ -188,4 +188,131 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
// CHECK-NEXT: LiveOut:{{ *$}}
%result = addi %sum, %arg2 : i32
return %result : i32
}
}
// -----
// CHECK-LABEL: Testing : nested_region
func @nested_region(
%arg0 : index, %arg1 : index, %arg2 : index,
%arg3 : i32, %arg4 : i32, %arg5 : i32,
%buffer : memref<i32>) -> i32 {
// CHECK: Block: 0
// CHECK-NEXT: LiveIn:{{ *$}}
// CHECK-NEXT: LiveOut:{{ *$}}
// CHECK-NEXT: BeginLiveness
// CHECK-NEXT: val_7
// CHECK-NEXT: %0 = addi
// CHECK-NEXT: %1 = addi
// CHECK-NEXT: loop.for
// CHECK: // %2 = addi
// CHECK-NEXT: %3 = addi
// CHECK-NEXT: val_8
// CHECK-NEXT: %1 = addi
// CHECK-NEXT: loop.for
// CHECK: // return %1
// CHECK: EndLiveness
%0 = addi %arg3, %arg4 : i32
%1 = addi %arg4, %arg5 : i32
loop.for %arg6 = %arg0 to %arg1 step %arg2 {
// CHECK: Block: 1
// CHECK-NEXT: LiveIn: arg5@0 arg6@0 val_7
// CHECK-NEXT: LiveOut:{{ *$}}
%2 = addi %0, %arg5 : i32
%3 = addi %2, %0 : i32
store %3, %buffer[] : memref<i32>
}
return %1 : i32
}
// -----
// CHECK-LABEL: Testing : nested_region2
func @nested_region2(
// CHECK: Block: 0
// CHECK-NEXT: LiveIn:{{ *$}}
// CHECK-NEXT: LiveOut:{{ *$}}
// CHECK-NEXT: BeginLiveness
// CHECK-NEXT: val_7
// CHECK-NEXT: %0 = addi
// CHECK-NEXT: %1 = addi
// CHECK-NEXT: loop.for
// CHECK: // %2 = addi
// CHECK-NEXT: loop.for
// CHECK: // %3 = addi
// CHECK-NEXT: val_8
// CHECK-NEXT: %1 = addi
// CHECK-NEXT: loop.for
// CHECK: // return %1
// CHECK: EndLiveness
%arg0 : index, %arg1 : index, %arg2 : index,
%arg3 : i32, %arg4 : i32, %arg5 : i32,
%buffer : memref<i32>) -> i32 {
%0 = addi %arg3, %arg4 : i32
%1 = addi %arg4, %arg5 : i32
loop.for %arg6 = %arg0 to %arg1 step %arg2 {
// CHECK: Block: 1
// CHECK-NEXT: LiveIn: arg0@0 arg1@0 arg2@0 arg5@0 arg6@0 val_7
// CHECK-NEXT: LiveOut:{{ *$}}
// CHECK-NEXT: BeginLiveness
// CHECK-NEXT: val_10
// CHECK-NEXT: %2 = addi
// CHECK-NEXT: loop.for
// CHECK: // %3 = addi
// CHECK: EndLiveness
%2 = addi %0, %arg5 : i32
loop.for %arg7 = %arg0 to %arg1 step %arg2 {
%3 = addi %2, %0 : i32
store %3, %buffer[] : memref<i32>
}
}
return %1 : i32
}
// -----
// CHECK-LABEL: Testing : nested_region3
func @nested_region3(
// CHECK: Block: 0
// CHECK-NEXT: LiveIn:{{ *$}}
// CHECK-NEXT: LiveOut: arg0@0 arg1@0 arg2@0 arg6@0 val_7 val_8
// CHECK-NEXT: BeginLiveness
// CHECK-NEXT: val_7
// CHECK-NEXT: %0 = addi
// CHECK-NEXT: %1 = addi
// CHECK-NEXT: loop.for
// CHECK: // br ^bb1
// CHECK-NEXT: %2 = addi
// CHECK-NEXT: loop.for
// CHECK: // %2 = addi
// CHECK: EndLiveness
%arg0 : index, %arg1 : index, %arg2 : index,
%arg3 : i32, %arg4 : i32, %arg5 : i32,
%buffer : memref<i32>) -> i32 {
%0 = addi %arg3, %arg4 : i32
%1 = addi %arg4, %arg5 : i32
loop.for %arg6 = %arg0 to %arg1 step %arg2 {
// CHECK: Block: 1
// CHECK-NEXT: LiveIn: arg5@0 arg6@0 val_7
// CHECK-NEXT: LiveOut:{{ *$}}
%2 = addi %0, %arg5 : i32
store %2, %buffer[] : memref<i32>
}
br ^exit
^exit:
// CHECK: Block: 2
// CHECK-NEXT: LiveIn: arg0@0 arg1@0 arg2@0 arg6@0 val_7 val_8
// CHECK-NEXT: LiveOut:{{ *$}}
loop.for %arg7 = %arg0 to %arg1 step %arg2 {
// CHECK: Block: 3
// CHECK-NEXT: LiveIn: arg6@0 val_7 val_8
// CHECK-NEXT: LiveOut:{{ *$}}
%2 = addi %0, %1 : i32
store %2, %buffer[] : memref<i32>
}
return %1 : i32
}