forked from OSchip/llvm-project
932 lines
35 KiB
C++
932 lines
35 KiB
C++
//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This transformation pass performs a sparse conditional constant propagation
|
|
// in MLIR. It identifies values known to be constant, propagates that
|
|
// information throughout the IR, and replaces them. This is done with an
|
|
// optimistic dataflow analysis that assumes that all values are constant until
|
|
// proven otherwise.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "PassDetail.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/FoldUtils.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
/// This class represents a single lattice value. A lattice value corresponds to
|
|
/// the various different states that a value in the SCCP dataflow analysis can
|
|
/// take. See 'Kind' below for more details on the different states a value can
|
|
/// take.
|
|
class LatticeValue {
|
|
enum Kind {
|
|
/// A value with a yet to be determined value. This state may be changed to
|
|
/// anything.
|
|
Unknown,
|
|
|
|
/// A value that is known to be a constant. This state may be changed to
|
|
/// overdefined.
|
|
Constant,
|
|
|
|
/// A value that cannot statically be determined to be a constant. This
|
|
/// state cannot be changed.
|
|
Overdefined
|
|
};
|
|
|
|
public:
|
|
/// Initialize a lattice value with "Unknown".
|
|
LatticeValue()
|
|
: constantAndTag(nullptr, Kind::Unknown), constantDialect(nullptr) {}
|
|
/// Initialize a lattice value with a constant.
|
|
LatticeValue(Attribute attr, Dialect *dialect)
|
|
: constantAndTag(attr, Kind::Constant), constantDialect(dialect) {}
|
|
|
|
/// Returns true if this lattice value is unknown.
|
|
bool isUnknown() const { return constantAndTag.getInt() == Kind::Unknown; }
|
|
|
|
/// Mark the lattice value as overdefined.
|
|
void markOverdefined() {
|
|
constantAndTag.setPointerAndInt(nullptr, Kind::Overdefined);
|
|
constantDialect = nullptr;
|
|
}
|
|
|
|
/// Returns true if the lattice is overdefined.
|
|
bool isOverdefined() const {
|
|
return constantAndTag.getInt() == Kind::Overdefined;
|
|
}
|
|
|
|
/// Mark the lattice value as constant.
|
|
void markConstant(Attribute value, Dialect *dialect) {
|
|
constantAndTag.setPointerAndInt(value, Kind::Constant);
|
|
constantDialect = dialect;
|
|
}
|
|
|
|
/// If this lattice is constant, return the constant. Returns nullptr
|
|
/// otherwise.
|
|
Attribute getConstant() const { return constantAndTag.getPointer(); }
|
|
|
|
/// If this lattice is constant, return the dialect to use when materializing
|
|
/// the constant.
|
|
Dialect *getConstantDialect() const {
|
|
assert(getConstant() && "expected valid constant");
|
|
return constantDialect;
|
|
}
|
|
|
|
/// Merge in the value of the 'rhs' lattice into this one. Returns true if the
|
|
/// lattice value changed.
|
|
bool meet(const LatticeValue &rhs) {
|
|
// If we are already overdefined, or rhs is unknown, there is nothing to do.
|
|
if (isOverdefined() || rhs.isUnknown())
|
|
return false;
|
|
// If we are unknown, just take the value of rhs.
|
|
if (isUnknown()) {
|
|
constantAndTag = rhs.constantAndTag;
|
|
constantDialect = rhs.constantDialect;
|
|
return true;
|
|
}
|
|
|
|
// Otherwise, if this value doesn't match rhs go straight to overdefined.
|
|
if (constantAndTag != rhs.constantAndTag) {
|
|
markOverdefined();
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
private:
|
|
/// The attribute value if this is a constant and the tag for the element
|
|
/// kind.
|
|
llvm::PointerIntPair<Attribute, 2, Kind> constantAndTag;
|
|
|
|
/// The dialect the constant originated from. This is only valid if the
|
|
/// lattice is a constant. This is not used as part of the key, and is only
|
|
/// needed to materialize the held constant if necessary.
|
|
Dialect *constantDialect;
|
|
};
|
|
|
|
/// This class contains various state used when computing the lattice of a
|
|
/// callable operation.
|
|
class CallableLatticeState {
|
|
public:
|
|
/// Build a lattice state with a given callable region, and a specified number
|
|
/// of results to be initialized to the default lattice value (Unknown).
|
|
CallableLatticeState(Region *callableRegion, unsigned numResults)
|
|
: callableArguments(callableRegion->getArguments()),
|
|
resultLatticeValues(numResults) {}
|
|
|
|
/// Returns the arguments to the callable region.
|
|
Block::BlockArgListType getCallableArguments() const {
|
|
return callableArguments;
|
|
}
|
|
|
|
/// Returns the lattice value for the results of the callable region.
|
|
MutableArrayRef<LatticeValue> getResultLatticeValues() {
|
|
return resultLatticeValues;
|
|
}
|
|
|
|
/// Add a call to this callable. This is only used if the callable defines a
|
|
/// symbol.
|
|
void addSymbolCall(Operation *op) { symbolCalls.push_back(op); }
|
|
|
|
/// Return the calls that reference this callable. This is only used
|
|
/// if the callable defines a symbol.
|
|
ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; }
|
|
|
|
private:
|
|
/// The arguments of the callable region.
|
|
Block::BlockArgListType callableArguments;
|
|
|
|
/// The lattice state for each of the results of this region. The return
|
|
/// values of the callable aren't SSA values, so we need to track them
|
|
/// separately.
|
|
SmallVector<LatticeValue, 4> resultLatticeValues;
|
|
|
|
/// The calls referencing this callable if this callable defines a symbol.
|
|
/// This removes the need to recompute symbol references during propagation.
|
|
/// Value based references are trivial to resolve, so they can be done
|
|
/// in-place.
|
|
SmallVector<Operation *, 4> symbolCalls;
|
|
};
|
|
|
|
/// This class represents the solver for the SCCP analysis. This class acts as
|
|
/// the propagation engine for computing which values form constants.
|
|
class SCCPSolver {
|
|
public:
|
|
/// Initialize the solver with the given top-level operation.
|
|
SCCPSolver(Operation *op);
|
|
|
|
/// Run the solver until it converges.
|
|
void solve();
|
|
|
|
/// Rewrite the given regions using the computing analysis. This replaces the
|
|
/// uses of all values that have been computed to be constant, and erases as
|
|
/// many newly dead operations.
|
|
void rewrite(MLIRContext *context, MutableArrayRef<Region> regions);
|
|
|
|
private:
|
|
/// Initialize the set of symbol defining callables that can have their
|
|
/// arguments and results tracked. 'op' is the top-level operation that SCCP
|
|
/// is operating on.
|
|
void initializeSymbolCallables(Operation *op);
|
|
|
|
/// Replace the given value with a constant if the corresponding lattice
|
|
/// represents a constant. Returns success if the value was replaced, failure
|
|
/// otherwise.
|
|
LogicalResult replaceWithConstant(OpBuilder &builder, OperationFolder &folder,
|
|
Value value);
|
|
|
|
/// Visit the users of the given IR that reside within executable blocks.
|
|
template <typename T>
|
|
void visitUsers(T &value) {
|
|
for (Operation *user : value.getUsers())
|
|
if (isBlockExecutable(user->getBlock()))
|
|
visitOperation(user);
|
|
}
|
|
|
|
/// Visit the given operation and compute any necessary lattice state.
|
|
void visitOperation(Operation *op);
|
|
|
|
/// Visit the given call operation and compute any necessary lattice state.
|
|
void visitCallOperation(CallOpInterface op);
|
|
|
|
/// Visit the given callable operation and compute any necessary lattice
|
|
/// state.
|
|
void visitCallableOperation(Operation *op);
|
|
|
|
/// Visit the given operation, which defines regions, and compute any
|
|
/// necessary lattice state. This also resolves the lattice state of both the
|
|
/// operation results and any nested regions.
|
|
void visitRegionOperation(Operation *op,
|
|
ArrayRef<Attribute> constantOperands);
|
|
|
|
/// Visit the given set of region successors, computing any necessary lattice
|
|
/// state. The provided function returns the input operands to the region at
|
|
/// the given index. If the index is 'None', the input operands correspond to
|
|
/// the parent operation results.
|
|
void visitRegionSuccessors(
|
|
Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
|
|
function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion);
|
|
|
|
/// Visit the given terminator operation and compute any necessary lattice
|
|
/// state.
|
|
void visitTerminatorOperation(Operation *op,
|
|
ArrayRef<Attribute> constantOperands);
|
|
|
|
/// Visit the given terminator operation that exits a callable region. These
|
|
/// are terminators with no CFG successors.
|
|
void visitCallableTerminatorOperation(Operation *callable,
|
|
Operation *terminator);
|
|
|
|
/// Visit the given block and compute any necessary lattice state.
|
|
void visitBlock(Block *block);
|
|
|
|
/// Visit argument #'i' of the given block and compute any necessary lattice
|
|
/// state.
|
|
void visitBlockArgument(Block *block, int i);
|
|
|
|
/// Mark the entry block of the given region as executable. Returns false if
|
|
/// the block was already marked executable. If `markArgsOverdefined` is true,
|
|
/// the arguments of the entry block are also set to overdefined.
|
|
bool markEntryBlockExecutable(Region *region, bool markArgsOverdefined);
|
|
|
|
/// Mark the given block as executable. Returns false if the block was already
|
|
/// marked executable.
|
|
bool markBlockExecutable(Block *block);
|
|
|
|
/// Returns true if the given block is executable.
|
|
bool isBlockExecutable(Block *block) const;
|
|
|
|
/// Mark the edge between 'from' and 'to' as executable.
|
|
void markEdgeExecutable(Block *from, Block *to);
|
|
|
|
/// Return true if the edge between 'from' and 'to' is executable.
|
|
bool isEdgeExecutable(Block *from, Block *to) const;
|
|
|
|
/// Mark the given value as overdefined. This means that we cannot refine a
|
|
/// specific constant for this value.
|
|
void markOverdefined(Value value);
|
|
|
|
/// Mark all of the given values as overdefined.
|
|
template <typename ValuesT>
|
|
void markAllOverdefined(ValuesT values) {
|
|
for (auto value : values)
|
|
markOverdefined(value);
|
|
}
|
|
template <typename ValuesT>
|
|
void markAllOverdefined(Operation *op, ValuesT values) {
|
|
markAllOverdefined(values);
|
|
opWorklist.push_back(op);
|
|
}
|
|
template <typename ValuesT>
|
|
void markAllOverdefinedAndVisitUsers(ValuesT values) {
|
|
for (auto value : values) {
|
|
auto &lattice = latticeValues[value];
|
|
if (!lattice.isOverdefined()) {
|
|
lattice.markOverdefined();
|
|
visitUsers(value);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Returns true if the given value was marked as overdefined.
|
|
bool isOverdefined(Value value) const;
|
|
|
|
/// Merge in the given lattice 'from' into the lattice 'to'. 'owner'
|
|
/// corresponds to the parent operation of 'to'.
|
|
void meet(Operation *owner, LatticeValue &to, const LatticeValue &from);
|
|
|
|
/// The lattice for each SSA value.
|
|
DenseMap<Value, LatticeValue> latticeValues;
|
|
|
|
/// The set of blocks that are known to execute, or are intrinsically live.
|
|
SmallPtrSet<Block *, 16> executableBlocks;
|
|
|
|
/// The set of control flow edges that are known to execute.
|
|
DenseSet<std::pair<Block *, Block *>> executableEdges;
|
|
|
|
/// A worklist containing blocks that need to be processed.
|
|
SmallVector<Block *, 64> blockWorklist;
|
|
|
|
/// A worklist of operations that need to be processed.
|
|
SmallVector<Operation *, 64> opWorklist;
|
|
|
|
/// The callable operations that have their argument/result state tracked.
|
|
DenseMap<Operation *, CallableLatticeState> callableLatticeState;
|
|
|
|
/// A map between a call operation and the resolved symbol callable. This
|
|
/// avoids re-resolving symbol references during propagation. Value based
|
|
/// callables are trivial to resolve, so they can be done in-place.
|
|
DenseMap<Operation *, Operation *> callToSymbolCallable;
|
|
|
|
/// A symbol table used for O(1) symbol lookups during simplification.
|
|
SymbolTableCollection symbolTable;
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
/// Seed the solver from the top-level operation `op`: the entry block of each
/// of its regions becomes executable, and the symbol callables nested under
/// `op` are registered for tracking.
SCCPSolver::SCCPSolver(Operation *op) {
  // Initialize the solver with the regions within this operation.
  for (Region &region : op->getRegions()) {
    // Mark the entry block as executable. The values passed to these regions
    // are also invisible, so mark any arguments as overdefined.
    markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
  }
  initializeSymbolCallables(op);
}
|
|
|
|
void SCCPSolver::solve() {
|
|
while (!blockWorklist.empty() || !opWorklist.empty()) {
|
|
// Process any operations in the op worklist.
|
|
while (!opWorklist.empty())
|
|
visitUsers(*opWorklist.pop_back_val());
|
|
|
|
// Process any blocks in the block worklist.
|
|
while (!blockWorklist.empty())
|
|
visitBlock(blockWorklist.pop_back_val());
|
|
}
|
|
}
|
|
|
|
/// Apply the computed lattice to the IR under `initialRegions`. Only blocks
/// the solver marked executable are visited. Block arguments and operation
/// results whose lattice holds a constant have their uses replaced, and an
/// operation is erased once all of its results were replaced and it would be
/// trivially dead.
void SCCPSolver::rewrite(MLIRContext *context,
                         MutableArrayRef<Region> initialRegions) {
  SmallVector<Block *, 8> worklist;
  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
    for (Region &region : regions)
      for (Block &block : region)
        if (isBlockExecutable(&block))
          worklist.push_back(&block);
  };

  // An operation folder used to create and unique constants.
  OperationFolder folder(context);
  OpBuilder builder(context);

  addToWorklist(initialRegions);
  while (!worklist.empty()) {
    Block *block = worklist.pop_back_val();

    // Replace any block arguments with constants.
    builder.setInsertionPointToStart(block);
    for (BlockArgument arg : block->getArguments())
      replaceWithConstant(builder, folder, arg);

    // Early-increment iteration keeps the walk valid while erasing `op`.
    for (Operation &op : llvm::make_early_inc_range(*block)) {
      builder.setInsertionPoint(&op);

      // Replace any result with constants. An operation without results never
      // counts as fully replaced.
      bool replacedAll = op.getNumResults() != 0;
      for (Value res : op.getResults())
        replacedAll &= succeeded(replaceWithConstant(builder, folder, res));

      // If all of the results of the operation were replaced, try to erase
      // the operation completely.
      if (replacedAll && wouldOpBeTriviallyDead(&op)) {
        assert(op.use_empty() && "expected all uses to be replaced");
        op.erase();
        continue;
      }

      // Add any of the regions of this operation to the worklist.
      addToWorklist(op.getRegions());
    }
  }
}
|
|
|
|
/// Walk the symbol tables nested under `op` and register every
/// symbol-defining callable that has a body, so that its arguments and
/// results can be tracked. Callables whose uses are not all visible, or that
/// are referenced by something other than a call, get their entry block
/// arguments marked overdefined. Calls that resolve to a tracked callable are
/// cached in `callToSymbolCallable`.
void SCCPSolver::initializeSymbolCallables(Operation *op) {
  // Initialize the set of symbol callables that can have their state tracked.
  // This tracks which symbol callable operations we can propagate within and
  // out of.
  auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
    Region &symbolTableRegion = symTable->getRegion(0);
    Block *symbolTableBlock = &symbolTableRegion.front();
    for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
      // We won't be able to track external callables.
      Region *callableRegion = callable.getCallableRegion();
      if (!callableRegion)
        continue;
      // We only care about symbol defining callables here.
      auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
      if (!symbol)
        continue;
      callableLatticeState.try_emplace(callable, callableRegion,
                                       callable.getCallableResults().size());

      // If not all of the uses of this symbol are visible, we can't track the
      // state of the arguments.
      if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
        for (Region &region : callable->getRegions())
          markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
      }
    }
    if (callableLatticeState.empty())
      return;

    // After computing the valid callables, walk any symbol uses to check
    // for non-call references. We won't be able to track the lattice state
    // for arguments to these callables, as we can't guarantee that we can see
    // all of its calls.
    Optional<SymbolTable::UseRange> uses =
        SymbolTable::getSymbolUses(&symbolTableRegion);
    if (!uses) {
      // If we couldn't gather the symbol uses, conservatively assume that
      // we can't track information for any nested symbols.
      op->walk([&](CallableOpInterface nestedCallable) {
        callableLatticeState.erase(nestedCallable);
      });
      return;
    }

    for (const SymbolTable::SymbolUse &use : *uses) {
      // If the use is a call, track it to avoid the need to recompute the
      // reference later.
      if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
        Operation *symCallable = callOp.resolveCallable(&symbolTable);
        auto callableLatticeIt = callableLatticeState.find(symCallable);
        if (callableLatticeIt != callableLatticeState.end()) {
          callToSymbolCallable.try_emplace(callOp, symCallable);

          // We only need to record the call in the lattice if it produces any
          // values.
          if (callOp->getNumResults())
            callableLatticeIt->second.addSymbolCall(callOp);
        }
        continue;
      }
      // This use isn't a call, so we don't know all of the callers.
      auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
      auto it = callableLatticeState.find(symbol);
      if (it != callableLatticeState.end()) {
        for (Region &region : it->first->getRegions())
          markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
      }
    }
  };
  // All symbol uses are treated as visible only when `op` has no parent block.
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);
}
|
|
|
|
/// Replace all uses of `value` with a constant materialized through `folder`
/// when the lattice for `value` holds one. Returns failure when no lattice
/// entry exists, when the lattice holds no constant, or when the constant
/// could not be materialized. On success the lattice entry for `value` is
/// dropped.
LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder,
                                              OperationFolder &folder,
                                              Value value) {
  // Without a lattice entry there is nothing known about the value.
  auto latticeIt = latticeValues.find(value);
  if (latticeIt == latticeValues.end())
    return failure();

  // The lattice must actually hold a constant.
  Attribute constantAttr = latticeIt->second.getConstant();
  if (!constantAttr)
    return failure();

  // Attempt to materialize the constant with the dialect recorded for it.
  Value materialized = folder.getOrCreateConstant(
      builder, latticeIt->second.getConstantDialect(), constantAttr,
      value.getType(), value.getLoc());
  if (!materialized)
    return failure();

  value.replaceAllUsesWith(materialized);
  latticeValues.erase(latticeIt);
  return success();
}
|
|
|
|
/// Visit `op` and update the lattice state that depends on it. Nothing is
/// done until every operand has left the Unknown state. Terminators, calls,
/// callables and region-holding operations are dispatched to their dedicated
/// visitors; every other operation has its results computed by simulating a
/// fold over the constant operands.
void SCCPSolver::visitOperation(Operation *op) {
  // Collect all of the constant operands feeding into this operation. If any
  // are not ready to be resolved, bail out and wait for them to resolve.
  SmallVector<Attribute, 8> operandConstants;
  operandConstants.reserve(op->getNumOperands());
  for (Value operand : op->getOperands()) {
    // Make sure all of the operands are resolved first.
    auto &operandLattice = latticeValues[operand];
    if (operandLattice.isUnknown())
      return;
    // Overdefined operands contribute a null attribute here.
    operandConstants.push_back(operandLattice.getConstant());
  }

  // If this is a terminator operation, process any control flow lattice state.
  if (op->isKnownTerminator())
    visitTerminatorOperation(op, operandConstants);

  // Process call operations. The call visitor processes result values, so we
  // can exit afterwards.
  if (CallOpInterface call = dyn_cast<CallOpInterface>(op))
    return visitCallOperation(call);

  // Process callable operations. These are specially handled region operations
  // that track dataflow via calls.
  if (isa<CallableOpInterface>(op)) {
    // If this callable has a tracked lattice state, it will be visited by calls
    // that reference it instead. This way, we don't assume that it is
    // executable unless there is a proper reference to it.
    if (callableLatticeState.count(op))
      return;
    return visitCallableOperation(op);
  }

  // Process region holding operations. The region visitor processes result
  // values, so we can exit afterwards.
  if (op->getNumRegions())
    return visitRegionOperation(op, operandConstants);

  // If this op produces no results, it can't produce any constants.
  if (op->getNumResults() == 0)
    return;

  // If all of the results of this operation are already overdefined, bail out
  // early.
  auto isOverdefinedFn = [&](Value value) { return isOverdefined(value); };
  if (llvm::all_of(op->getResults(), isOverdefinedFn))
    return;

  // Save the original operands and attributes just in case the operation folds
  // in-place. The constant passed in may not correspond to the real runtime
  // value, so in-place updates are not allowed.
  SmallVector<Value, 8> originalOperands(op->getOperands());
  DictionaryAttr originalAttrs = op->getAttrDictionary();

  // Simulate the result of folding this operation to a constant. If folding
  // fails or was an in-place fold, mark the results as overdefined.
  SmallVector<OpFoldResult, 8> foldResults;
  foldResults.reserve(op->getNumResults());
  if (failed(op->fold(operandConstants, foldResults)))
    return markAllOverdefined(op, op->getResults());

  // If the folding was in-place, mark the results as overdefined and reset the
  // operation. We don't allow in-place folds as the desire here is for
  // simulated execution, and not general folding.
  if (foldResults.empty()) {
    op->setOperands(originalOperands);
    op->setAttrs(originalAttrs);
    return markAllOverdefined(op, op->getResults());
  }

  // Merge the fold results into the lattice for this operation.
  assert(foldResults.size() == op->getNumResults() && "invalid result size");
  Dialect *opDialect = op->getDialect();
  for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
    // Resolve the lattice to merge in, either a constant or a value, and take
    // it by copy *before* looking up the result lattice. Indexing
    // `latticeValues` may insert and reallocate the map, which would leave a
    // previously obtained reference into it dangling.
    OpFoldResult foldResult = foldResults[i];
    LatticeValue foldLattice;
    if (Attribute foldAttr = foldResult.dyn_cast<Attribute>())
      foldLattice = LatticeValue(foldAttr, opDialect);
    else
      foldLattice = latticeValues[foldResult.get<Value>()];

    meet(op, latticeValues[op->getResult(i)], foldLattice);
  }
}
|
|
|
|
/// Visit the callable operation `op`: the entry block of each of its regions
/// becomes executable, and every result of `op` itself is marked overdefined.
void SCCPSolver::visitCallableOperation(Operation *op) {
  // Mark the regions as executable. If we aren't tracking lattice state for
  // this callable, mark all of the region arguments as overdefined.
  bool isTrackingLatticeState = callableLatticeState.count(op);
  for (Region &region : op->getRegions())
    markEntryBlockExecutable(&region, !isTrackingLatticeState);

  // TODO: Add support for non-symbol callables when necessary. If the callable
  // has non-call uses we would mark overdefined, otherwise allow for
  // propagating the return values out.
  markAllOverdefined(op, op->getResults());
}
|
|
|
|
/// Visit the call operation `op`. When the callee resolves to a callable with
/// tracked lattice state, the call operands are merged into the callable's
/// arguments and the callable's result lattices are merged into the call
/// results. Otherwise all call results are marked overdefined.
void SCCPSolver::visitCallOperation(CallOpInterface op) {
  ResultRange callResults = op->getResults();

  // Resolve the callable operation for this call. Value based callees are
  // resolved through their defining operation, symbol based ones through the
  // cache built during initialization.
  Operation *callableOp = nullptr;
  if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>())
    callableOp = callableValue.getDefiningOp();
  else
    callableOp = callToSymbolCallable.lookup(op);

  // The callable of this call can't be resolved, mark any results overdefined.
  if (!callableOp)
    return markAllOverdefined(op, callResults);

  // If this callable is tracking state, merge the argument operands with the
  // arguments of the callable.
  auto callableLatticeIt = callableLatticeState.find(callableOp);
  if (callableLatticeIt == callableLatticeState.end())
    return markAllOverdefined(op, callResults);

  OperandRange callOperands = op.getArgOperands();
  auto callableArgs = callableLatticeIt->second.getCallableArguments();
  for (auto it : llvm::zip(callOperands, callableArgs)) {
    BlockArgument callableArg = std::get<1>(it);
    // Copy the operand lattice before indexing the map again. A second
    // `latticeValues[...]` lookup may insert and reallocate the map, which
    // would invalidate a reference obtained from the first lookup.
    LatticeValue operandLattice = latticeValues[std::get<0>(it)];
    if (latticeValues[callableArg].meet(operandLattice))
      visitUsers(callableArg);
  }

  // Visit the callable.
  visitCallableOperation(callableOp);

  // Merge in the lattice state for the callable results as well.
  auto callableResults = callableLatticeIt->second.getResultLatticeValues();
  for (auto it : llvm::zip(callResults, callableResults))
    meet(/*owner=*/op, /*to=*/latticeValues[std::get<0>(it)],
         /*from=*/std::get<1>(it));
}
|
|
|
|
/// Visit the region-holding operation `op`. When `op` implements
/// RegionBranchOpInterface, the regions it may enter (given
/// `constantOperands`) are visited as region successors. Otherwise every
/// region is conservatively made executable and all results of `op` are
/// marked overdefined.
void SCCPSolver::visitRegionOperation(Operation *op,
                                      ArrayRef<Attribute> constantOperands) {
  // Check to see if we can reason about the internal control flow of this
  // region operation.
  auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
  if (!regionInterface) {
    // If we can't, conservatively mark all regions as executable.
    for (Region &region : op->getRegions())
      markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);

    // Don't try to simulate the results of a region operation as we can't
    // guarantee that folding will be out-of-place. We don't allow in-place
    // folds as the desire here is for simulated execution, and not general
    // folding.
    return markAllOverdefined(op, op->getResults());
  }

  // Check to see which regions are executable.
  SmallVector<RegionSuccessor, 1> successors;
  regionInterface.getSuccessorRegions(/*index=*/llvm::None, constantOperands,
                                      successors);

  // If the interface identified that no region will be executed, mark any
  // results of this operation as overdefined, as we can't reason about them.
  // TODO: If we had an interface to detect pass through operands, we could
  // resolve some results based on the lattice state of the operands. We could
  // also allow for the parent operation to have itself as a region successor.
  if (successors.empty())
    return markAllOverdefined(op, op->getResults());
  return visitRegionSuccessors(op, successors, [&](Optional<unsigned> index) {
    assert(index && "expected valid region index");
    return regionInterface.getSuccessorEntryOperands(*index);
  });
}
|
|
|
|
/// Propagate lattice state into each entry of `regionSuccessors`. A successor
/// without a region denotes `parentOp` itself, in which case the inputs are
/// merged into the results of `parentOp`. Otherwise the successor region's
/// entry block becomes executable and the inputs are merged into its
/// arguments. `getInputsForRegion` supplies the operands that flow into the
/// region with the given index ('None' for the parent operation).
void SCCPSolver::visitRegionSuccessors(
    Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
    function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) {
  for (const RegionSuccessor &it : regionSuccessors) {
    Region *region = it.getSuccessor();
    ValueRange succArgs = it.getSuccessorInputs();

    // Check to see if this is the parent operation.
    if (!region) {
      ResultRange results = parentOp->getResults();
      if (llvm::all_of(results, [&](Value res) { return isOverdefined(res); }))
        continue;

      // Mark the results outside of the input range as overdefined.
      if (succArgs.size() != results.size()) {
        opWorklist.push_back(parentOp);
        if (succArgs.empty()) {
          // Move on to the next successor instead of leaving the function, so
          // that the remaining successors are still processed. This mirrors
          // the handling of region successors without inputs below.
          markAllOverdefined(results);
          continue;
        }

        unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber();
        markAllOverdefined(results.take_front(firstResIdx));
        markAllOverdefined(results.drop_front(firstResIdx + succArgs.size()));
      }

      // Update the lattice for any operation results. The operand lattice is
      // copied first: indexing `latticeValues` for the result may insert and
      // reallocate the map, which would invalidate a reference into it.
      OperandRange operands = getInputsForRegion(/*index=*/llvm::None);
      for (auto resultAndOperand : llvm::zip(succArgs, operands)) {
        LatticeValue operandLattice =
            latticeValues[std::get<1>(resultAndOperand)];
        meet(parentOp, latticeValues[std::get<0>(resultAndOperand)],
             operandLattice);
      }
      // NOTE(review): this leaves the function, so successors listed after
      // the parent operation are not visited -- confirm this is intended.
      return;
    }
    assert(!region->empty() && "expected region to be non-empty");
    Block *entryBlock = &region->front();
    markBlockExecutable(entryBlock);

    // If all of the arguments are already overdefined, the arguments have
    // already been fully resolved.
    auto arguments = entryBlock->getArguments();
    if (llvm::all_of(arguments, [&](Value arg) { return isOverdefined(arg); }))
      continue;

    // Mark any arguments that do not receive inputs as overdefined, we won't be
    // able to discern if they are constant.
    if (succArgs.size() != arguments.size()) {
      if (succArgs.empty()) {
        markAllOverdefined(arguments);
        continue;
      }

      unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber();
      markAllOverdefinedAndVisitUsers(arguments.take_front(firstArgIdx));
      markAllOverdefinedAndVisitUsers(
          arguments.drop_front(firstArgIdx + succArgs.size()));
    }

    // Update the lattice for arguments that have inputs from the predecessor.
    // As above, copy the operand lattice before indexing the map again.
    OperandRange succOperands = getInputsForRegion(region->getRegionNumber());
    for (auto argAndOperand : llvm::zip(succArgs, succOperands)) {
      LatticeValue operandLattice = latticeValues[std::get<1>(argAndOperand)];
      if (latticeValues[std::get<0>(argAndOperand)].meet(operandLattice))
        visitUsers(std::get<0>(argAndOperand));
    }
  }
}
|
|
|
|
/// Visit the terminator `op`. A terminator with CFG successors marks the
/// outgoing edges that may be taken as executable. A terminator without
/// successors exits its region: it is forwarded either to the callable
/// terminator visitor or, for parents implementing RegionBranchOpInterface,
/// to the region successor visitor.
void SCCPSolver::visitTerminatorOperation(
    Operation *op, ArrayRef<Attribute> constantOperands) {
  // Terminators that branch within the CFG.
  if (op->getNumSuccessors() != 0) {
    Block *currentBlock = op->getBlock();

    // With constant operands the branch may resolve to a single target, in
    // which case only that edge becomes executable.
    if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
      Block *onlyTarget = branchOp.getSuccessorForOperands(constantOperands);
      if (onlyTarget) {
        markEdgeExecutable(currentBlock, onlyTarget);
        return;
      }
    }

    // No single target could be determined; conservatively treat every
    // outgoing edge as executable.
    for (Block *target : op->getSuccessors())
      markEdgeExecutable(currentBlock, target);
    return;
  }

  // From here on the terminator has no successors and is treated as exiting
  // its parent region.
  Region *exitedRegion = op->getParentRegion();
  Operation *regionOwner = exitedRegion->getParentOp();

  // Terminators of callable regions have their own visitor.
  if (isa<CallableOpInterface>(regionOwner)) {
    visitCallableTerminatorOperation(regionOwner, op);
    return;
  }

  // Otherwise the owner has to describe its region control flow, and it has
  // to sit in an executable block.
  auto branchInterface = dyn_cast<RegionBranchOpInterface>(regionOwner);
  if (!branchInterface)
    return;
  if (!isBlockExecutable(regionOwner->getBlock()))
    return;

  // Ask the owner where control may go when leaving the current region.
  SmallVector<RegionSuccessor, 1> exitSuccessors;
  branchInterface.getSuccessorRegions(exitedRegion->getRegionNumber(),
                                      constantOperands, exitSuccessors);
  if (exitSuccessors.empty())
    return;

  // The operands of a return-like terminator are propagated to each
  // successor.
  if (op->hasTrait<OpTrait::ReturnLike>()) {
    OperandRange forwardedOperands = op->getOperands();
    visitRegionSuccessors(
        regionOwner, exitSuccessors,
        [&](Optional<unsigned>) { return forwardedOperands; });
    return;
  }

  // For any other terminator the forwarded values are not known, so every
  // successor input is conservatively marked overdefined.
  for (auto &successor : exitSuccessors)
    markAllOverdefinedAndVisitUsers(successor.getSuccessorInputs());
}
|
|
|
|
/// Visit `terminator`, which exits the region of `callable`. When the
/// callable has tracked lattice state, the terminator operands are merged
/// into the callable's result lattices, and any change is forwarded to the
/// results of the calls recorded for the callable.
void SCCPSolver::visitCallableTerminatorOperation(Operation *callable,
                                                  Operation *terminator) {
  // A terminator without operands returns nothing that could be tracked.
  if (terminator->getNumOperands() == 0)
    return;

  // Only callables with tracked lattice state are of interest.
  auto stateIt = callableLatticeState.find(callable);
  if (stateIt == callableLatticeState.end())
    return;
  assert(callable->getNumResults() == 0 && "expected symbol callable");

  CallableLatticeState &state = stateIt->second;
  auto resultLattices = state.getResultLatticeValues();

  if (terminator->hasTrait<OpTrait::ReturnLike>()) {
    // Fold the returned operands into the result lattices of the callable.
    bool changed = false;
    for (auto operandAndLattice :
         llvm::zip(terminator->getOperands(), resultLattices)) {
      LatticeValue &resultLattice = std::get<1>(operandAndLattice);
      changed |= resultLattice.meet(latticeValues[std::get<0>(operandAndLattice)]);
    }
    if (!changed)
      return;

    // Something changed, so refresh the results of every recorded call.
    for (Operation *call : state.getSymbolCalls()) {
      for (auto resultAndLattice :
           llvm::zip(call->getResults(), resultLattices)) {
        meet(call, latticeValues[std::get<0>(resultAndLattice)],
             std::get<1>(resultAndLattice));
      }
    }
    return;
  }

  // The terminator is not return-like, so what it hands back is not known:
  // conservatively mark the callable results and the results of every
  // recorded call as overdefined.
  for (auto &resultLattice : resultLattices)
    resultLattice.markOverdefined();
  for (Operation *call : state.getSymbolCalls())
    markAllOverdefined(call, call->getResults());
}
|
|
|
|
/// Visit `block`: recompute the lattice state of its arguments (unless it is
/// an entry block) and then visit every operation it contains.
void SCCPSolver::visitBlock(Block *block) {
  // Entry block argument lattices are computed elsewhere (e.g. when visiting
  // the parent operation), so only non-entry blocks are handled here.
  if (!block->isEntryBlock()) {
    for (int argIdx = 0, numArgs = block->getNumArguments(); argIdx != numArgs;
         ++argIdx)
      visitBlockArgument(block, argIdx);
  }

  // Process the operations of the block in order.
  for (Operation &nestedOp : *block)
    visitOperation(&nestedOp);
}
|
|
|
|
/// Recompute the lattice value of argument number `i` of `block` by meeting
/// the values forwarded along every executable incoming edge. If the lattice
/// changes, the users of the argument are visited.
void SCCPSolver::visitBlockArgument(Block *block, int i) {
  BlockArgument arg = block->getArgument(i);
  LatticeValue &argLattice = latticeValues[arg];

  // An overdefined lattice cannot change any further.
  if (argLattice.isOverdefined())
    return;

  bool changed = false;
  for (auto predIt = block->pred_begin(), predEnd = block->pred_end();
       predIt != predEnd; ++predIt) {
    Block *predecessor = *predIt;

    // Edges that are not known to execute contribute nothing.
    if (!isEdgeExecutable(predecessor, block))
      continue;

    // Ask the predecessor's terminator which operands it forwards along this
    // edge. If the terminator is not a branch interface, or it cannot provide
    // the operands, the argument has to be treated as overdefined.
    Optional<OperandRange> forwardedOperands;
    if (auto branch = dyn_cast<BranchOpInterface>(predecessor->getTerminator()))
      forwardedOperands =
          branch.getSuccessorOperands(predIt.getSuccessorIndex());
    if (!forwardedOperands) {
      argLattice.markOverdefined();
      changed = true;
      break;
    }

    // An operand without a lattice entry is still unknown and merges with
    // anything; otherwise meet its lattice into the argument's.
    auto operandIt = latticeValues.find((*forwardedOperands)[i]);
    if (operandIt != latticeValues.end()) {
      changed |= argLattice.meet(operandIt->second);

      // No need to look at further predecessors once overdefined.
      if (argLattice.isOverdefined())
        break;
    }
  }

  // Propagate a changed lattice to the users of the argument.
  if (changed)
    visitUsers(arg);
}
|
|
|
|
/// Mark the entry block of `region` as executable. When `markArgsOverdefined`
/// is set, the arguments of the entry block are marked overdefined first.
/// Returns false if the region is empty; otherwise returns the result of
/// marking the entry block executable.
bool SCCPSolver::markEntryBlockExecutable(Region *region,
                                          bool markArgsOverdefined) {
  // An empty region has no entry block to mark.
  if (region->empty())
    return false;

  Block &entryBlock = region->front();
  if (markArgsOverdefined)
    markAllOverdefined(entryBlock.getArguments());
  return markBlockExecutable(&entryBlock);
}
|
|
|
|
/// Record `block` as executable. If it was not already recorded, it is queued
/// on the block worklist and true is returned; otherwise returns false.
bool SCCPSolver::markBlockExecutable(Block *block) {
  // Already known to be executable: nothing new to schedule.
  if (!executableBlocks.insert(block).second)
    return false;

  blockWorklist.push_back(block);
  return true;
}
|
|
|
|
/// Returns true if `block` has been recorded as executable.
bool SCCPSolver::isBlockExecutable(Block *block) const {
  return executableBlocks.count(block) != 0;
}
|
|
|
|
/// Record the control-flow edge `from` -> `to` as executable. Nothing happens
/// if the edge was already recorded. Otherwise `to` is marked executable, and
/// if it already was, its arguments are re-visited to account for the new
/// incoming edge.
void SCCPSolver::markEdgeExecutable(Block *from, Block *to) {
  const bool isNewEdge = executableEdges.insert(std::make_pair(from, to)).second;
  if (!isNewEdge)
    return;

  // A newly executable destination is queued on the worklist by
  // markBlockExecutable, so only an already-executable one needs its
  // arguments reprocessed here.
  if (markBlockExecutable(to))
    return;
  for (int argIdx = 0, numArgs = to->getNumArguments(); argIdx != numArgs;
       ++argIdx)
    visitBlockArgument(to, argIdx);
}
|
|
|
|
/// Returns true if the edge `from` -> `to` has been recorded as executable.
bool SCCPSolver::isEdgeExecutable(Block *from, Block *to) const {
  auto edge = std::make_pair(from, to);
  return executableEdges.count(edge) != 0;
}
|
|
|
|
/// Mark the lattice entry of `value` as overdefined, creating the entry if it
/// does not exist yet.
void SCCPSolver::markOverdefined(Value value) {
  LatticeValue &lattice = latticeValues[value];
  lattice.markOverdefined();
}
|
|
|
|
/// Returns true if `value` has a lattice entry and that entry is overdefined.
/// A value without an entry is not considered overdefined.
bool SCCPSolver::isOverdefined(Value value) const {
  auto latticeIt = latticeValues.find(value);
  if (latticeIt == latticeValues.end())
    return false;
  return latticeIt->second.isOverdefined();
}
|
|
|
|
/// Meet the lattice `from` into `to`. If `to` changed as a result, `owner` is
/// pushed onto the operation worklist.
void SCCPSolver::meet(Operation *owner, LatticeValue &to,
                      const LatticeValue &from) {
  const bool changed = to.meet(from);
  if (changed)
    opWorklist.push_back(owner);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SCCP Pass
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
struct SCCP : public SCCPBase<SCCP> {
|
|
void runOnOperation() override;
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
void SCCP::runOnOperation() {
|
|
Operation *op = getOperation();
|
|
|
|
// Solve for SCCP constraints within nested regions.
|
|
SCCPSolver solver(op);
|
|
solver.solve();
|
|
|
|
// Cleanup any operations using the solver analysis.
|
|
solver.rewrite(&getContext(), op->getRegions());
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createSCCPPass() {
|
|
return std::make_unique<SCCP>();
|
|
}
|