forked from OSchip/llvm-project
[mlir][SCCP] Add support for propagating across symbol based calls
This revision adds support for propagating constants across symbol-based callgraph edges. It uses the existing Call/CallableOpInterfaces to detect the dataflow edges, and propagates constants through arguments and out of returns. Differential Revision: https://reviews.llvm.org/D78592
This commit is contained in:
parent
7c221a7d4f
commit
a90151d67e
|
@ -86,6 +86,15 @@ public:
|
|||
/// nullptr if no valid parent symbol table could be found.
|
||||
static Operation *getNearestSymbolTable(Operation *from);
|
||||
|
||||
/// Walks all symbol table operations nested within, and including, `op`. For
|
||||
/// each symbol table operation, the provided callback is invoked with the op
|
||||
/// and a boolean signifying if the symbols within that symbol table can be
|
||||
/// treated as if all uses within the IR are visible to the caller.
|
||||
/// `allSymUsesVisible` identifies whether all of the symbol uses of symbols
|
||||
/// within `op` are visible.
|
||||
static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
|
||||
function_ref<void(Operation *, bool)> callback);
|
||||
|
||||
/// Returns the operation registered with the given symbol name with the
|
||||
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
|
||||
/// with the 'OpTrait::SymbolTable' trait.
|
||||
|
|
|
@ -34,7 +34,8 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
|
|||
InterfaceMethod<[{
|
||||
Returns the callee of this call-like operation. A `callee` is either a
|
||||
reference to a symbol, via SymbolRefAttr, or a reference to a defined
|
||||
SSA value.
|
||||
SSA value. If the reference is an SSA value, the SSA value corresponds
|
||||
to a region of a lambda-like operation.
|
||||
}],
|
||||
"CallInterfaceCallable", "getCallableForCallee"
|
||||
>,
|
||||
|
|
|
@ -207,6 +207,35 @@ Operation *SymbolTable::getNearestSymbolTable(Operation *from) {
|
|||
return from;
|
||||
}
|
||||
|
||||
/// Walks all symbol table operations nested within, and including, `op`. For
|
||||
/// each symbol table operation, the provided callback is invoked with the op
|
||||
/// and a boolean signifying if the symbols within that symbol table can be
|
||||
/// treated as if all uses are visible. `allSymUsesVisible` identifies whether
|
||||
/// all of the symbol uses of symbols within `op` are visible.
|
||||
void SymbolTable::walkSymbolTables(
|
||||
Operation *op, bool allSymUsesVisible,
|
||||
function_ref<void(Operation *, bool)> callback) {
|
||||
bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
|
||||
if (isSymbolTable) {
|
||||
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
|
||||
allSymUsesVisible |= !symbol || symbol.isPrivate();
|
||||
} else {
|
||||
// Otherwise if 'op' is not a symbol table, any nested symbols are
|
||||
// guaranteed to be hidden.
|
||||
allSymUsesVisible = true;
|
||||
}
|
||||
|
||||
for (Region ®ion : op->getRegions())
|
||||
for (Block &block : region)
|
||||
for (Operation &nestedOp : block)
|
||||
walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
|
||||
|
||||
// If 'op' had the symbol table trait, visit it after any nested symbol
|
||||
// tables.
|
||||
if (isSymbolTable)
|
||||
callback(op, allSymUsesVisible);
|
||||
}
|
||||
|
||||
/// Returns the operation registered with the given symbol name with the
|
||||
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
|
||||
/// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
|
||||
|
|
|
@ -31,29 +31,6 @@ using namespace mlir;
|
|||
// Symbol Use Tracking
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Walk all of the symbol table operations nested with 'op' along with a
|
||||
/// boolean signifying if the symbols within can be treated as if all uses are
|
||||
/// visible. The provided callback is invoked with the symbol table operation,
|
||||
/// and a boolean signaling if all of the uses within the symbol table are
|
||||
/// visible.
|
||||
static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
|
||||
function_ref<void(Operation *, bool)> callback) {
|
||||
if (op->hasTrait<OpTrait::SymbolTable>()) {
|
||||
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
|
||||
allSymUsesVisible = allSymUsesVisible || !symbol || symbol.isPrivate();
|
||||
callback(op, allSymUsesVisible);
|
||||
} else {
|
||||
// Otherwise if 'op' is not a symbol table, any nested symbols are
|
||||
// guaranteed to be hidden.
|
||||
allSymUsesVisible = true;
|
||||
}
|
||||
|
||||
for (Region ®ion : op->getRegions())
|
||||
for (Block &block : region)
|
||||
for (Operation &nested : block)
|
||||
walkSymbolTables(&nested, allSymUsesVisible, callback);
|
||||
}
|
||||
|
||||
/// Walk all of the used symbol callgraph nodes referenced with the given op.
|
||||
static void walkReferencedSymbolNodes(
|
||||
Operation *op, CallGraph &cg,
|
||||
|
@ -164,7 +141,8 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
|
|||
}
|
||||
}
|
||||
};
|
||||
walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), walkFn);
|
||||
SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
|
||||
walkFn);
|
||||
|
||||
// Drop the use information for any discardable nodes that are always live.
|
||||
for (auto &it : alwaysLiveNodes)
|
||||
|
|
|
@ -116,12 +116,56 @@ private:
|
|||
Dialect *constantDialect;
|
||||
};
|
||||
|
||||
/// This class contains various state used when computing the lattice of a
|
||||
/// callable operation.
|
||||
class CallableLatticeState {
|
||||
public:
|
||||
/// Build a lattice state with a given callable region, and a specified number
|
||||
/// of results to be initialized to the default lattice value (Unknown).
|
||||
CallableLatticeState(Region *callableRegion, unsigned numResults)
|
||||
: callableArguments(callableRegion->front().getArguments()),
|
||||
resultLatticeValues(numResults) {}
|
||||
|
||||
/// Returns the arguments to the callable region.
|
||||
Block::BlockArgListType getCallableArguments() const {
|
||||
return callableArguments;
|
||||
}
|
||||
|
||||
/// Returns the lattice value for the results of the callable region.
|
||||
MutableArrayRef<LatticeValue> getResultLatticeValues() {
|
||||
return resultLatticeValues;
|
||||
}
|
||||
|
||||
/// Add a call to this callable. This is only used if the callable defines a
|
||||
/// symbol.
|
||||
void addSymbolCall(Operation *op) { symbolCalls.push_back(op); }
|
||||
|
||||
/// Return the calls that reference this callable. This is only used
|
||||
/// if the callable defines a symbol.
|
||||
ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; }
|
||||
|
||||
private:
|
||||
/// The arguments of the callable region.
|
||||
Block::BlockArgListType callableArguments;
|
||||
|
||||
/// The lattice state for each of the results of this region. The return
|
||||
/// values of the callable aren't SSA values, so we need to track them
|
||||
/// separately.
|
||||
SmallVector<LatticeValue, 4> resultLatticeValues;
|
||||
|
||||
/// The calls referencing this callable if this callable defines a symbol.
|
||||
/// This removes the need to recompute symbol references during propagation.
|
||||
/// Value based references are trivial to resolve, so they can be done
|
||||
/// in-place.
|
||||
SmallVector<Operation *, 4> symbolCalls;
|
||||
};
|
||||
|
||||
/// This class represents the solver for the SCCP analysis. This class acts as
|
||||
/// the propagation engine for computing which values form constants.
|
||||
class SCCPSolver {
|
||||
public:
|
||||
/// Initialize the solver with a given set of regions.
|
||||
SCCPSolver(MutableArrayRef<Region> regions);
|
||||
/// Initialize the solver with the given top-level operation.
|
||||
SCCPSolver(Operation *op);
|
||||
|
||||
/// Run the solver until it converges.
|
||||
void solve();
|
||||
|
@ -132,6 +176,11 @@ public:
|
|||
void rewrite(MLIRContext *context, MutableArrayRef<Region> regions);
|
||||
|
||||
private:
|
||||
/// Initialize the set of symbol defining callables that can have their
|
||||
/// arguments and results tracked. 'op' is the top-level operation that SCCP
|
||||
/// is operating on.
|
||||
void initializeSymbolCallables(Operation *op);
|
||||
|
||||
/// Replace the given value with a constant if the corresponding lattice
|
||||
/// represents a constant. Returns success if the value was replaced, failure
|
||||
/// otherwise.
|
||||
|
@ -149,6 +198,13 @@ private:
|
|||
/// Visit the given operation and compute any necessary lattice state.
|
||||
void visitOperation(Operation *op);
|
||||
|
||||
/// Visit the given call operation and compute any necessary lattice state.
|
||||
void visitCallOperation(CallOpInterface op);
|
||||
|
||||
/// Visit the given callable operation and compute any necessary lattice
|
||||
/// state.
|
||||
void visitCallableOperation(Operation *op);
|
||||
|
||||
/// Visit the given operation, which defines regions, and compute any
|
||||
/// necessary lattice state. This also resolves the lattice state of both the
|
||||
/// operation results and any nested regions.
|
||||
|
@ -168,6 +224,11 @@ private:
|
|||
void visitTerminatorOperation(Operation *op,
|
||||
ArrayRef<Attribute> constantOperands);
|
||||
|
||||
/// Visit the given terminator operation that exits a callable region. These
|
||||
/// are terminators with no CFG successors.
|
||||
void visitCallableTerminatorOperation(Operation *callable,
|
||||
Operation *terminator);
|
||||
|
||||
/// Visit the given block and compute any necessary lattice state.
|
||||
void visitBlock(Block *block);
|
||||
|
||||
|
@ -235,11 +296,20 @@ private:
|
|||
|
||||
/// A worklist of operations that need to be processed.
|
||||
SmallVector<Operation *, 64> opWorklist;
|
||||
|
||||
/// The callable operations that have their argument/result state tracked.
|
||||
DenseMap<Operation *, CallableLatticeState> callableLatticeState;
|
||||
|
||||
/// A map between a call operation and the resolved symbol callable. This
|
||||
/// avoids re-resolving symbol references during propagation. Value based
|
||||
/// callables are trivial to resolve, so they can be done in-place.
|
||||
DenseMap<Operation *, Operation *> callToSymbolCallable;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
SCCPSolver::SCCPSolver(MutableArrayRef<Region> regions) {
|
||||
for (Region ®ion : regions) {
|
||||
SCCPSolver::SCCPSolver(Operation *op) {
|
||||
/// Initialize the solver with the regions within this operation.
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
if (region.empty())
|
||||
continue;
|
||||
Block *entryBlock = ®ion.front();
|
||||
|
@ -251,6 +321,7 @@ SCCPSolver::SCCPSolver(MutableArrayRef<Region> regions) {
|
|||
// as overdefined.
|
||||
markAllOverdefined(entryBlock->getArguments());
|
||||
}
|
||||
initializeSymbolCallables(op);
|
||||
}
|
||||
|
||||
void SCCPSolver::solve() {
|
||||
|
@ -310,6 +381,73 @@ void SCCPSolver::rewrite(MLIRContext *context,
|
|||
}
|
||||
}
|
||||
|
||||
void SCCPSolver::initializeSymbolCallables(Operation *op) {
|
||||
// Initialize the set of symbol callables that can have their state tracked.
|
||||
// This tracks which symbol callable operations we can propagate within and
|
||||
// out of.
|
||||
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
|
||||
Region &symbolTableRegion = symTable->getRegion(0);
|
||||
Block *symbolTableBlock = &symbolTableRegion.front();
|
||||
for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
|
||||
// We won't be able to track external callables.
|
||||
Region *callableRegion = callable.getCallableRegion();
|
||||
if (!callableRegion)
|
||||
continue;
|
||||
// We only care about symbol defining callables here.
|
||||
auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
|
||||
if (!symbol)
|
||||
continue;
|
||||
callableLatticeState.try_emplace(callable, callableRegion,
|
||||
callable.getCallableResults().size());
|
||||
|
||||
// If not all of the uses of this symbol are visible, we can't track the
|
||||
// state of the arguments.
|
||||
if (symbol.isPublic() || (!allUsesVisible && symbol.isNested()))
|
||||
markAllOverdefined(callableRegion->front().getArguments());
|
||||
}
|
||||
if (callableLatticeState.empty())
|
||||
return;
|
||||
|
||||
// After computing the valid callables, walk any symbol uses to check
|
||||
// for non-call references. We won't be able to track the lattice state
|
||||
// for arguments to these callables, as we can't guarantee that we can see
|
||||
// all of its calls.
|
||||
Optional<SymbolTable::UseRange> uses =
|
||||
SymbolTable::getSymbolUses(&symbolTableRegion);
|
||||
if (!uses) {
|
||||
// If we couldn't gather the symbol uses, conservatively assume that
|
||||
// we can't track information for any nested symbols.
|
||||
op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); });
|
||||
return;
|
||||
}
|
||||
|
||||
for (const SymbolTable::SymbolUse &use : *uses) {
|
||||
// If the use is a call, track it to avoid the need to recompute the
|
||||
// reference later.
|
||||
if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
|
||||
Operation *symCallable = callOp.resolveCallable();
|
||||
auto callableLatticeIt = callableLatticeState.find(symCallable);
|
||||
if (callableLatticeIt != callableLatticeState.end()) {
|
||||
callToSymbolCallable.try_emplace(callOp, symCallable);
|
||||
|
||||
// We only need to record the call in the lattice if it produces any
|
||||
// values.
|
||||
if (callOp.getOperation()->getNumResults())
|
||||
callableLatticeIt->second.addSymbolCall(callOp);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// This use isn't a call, so don't we know all of the callers.
|
||||
auto *symbol = SymbolTable::lookupSymbolIn(op, use.getSymbolRef());
|
||||
auto it = callableLatticeState.find(symbol);
|
||||
if (it != callableLatticeState.end())
|
||||
markAllOverdefined(it->second.getCallableArguments());
|
||||
}
|
||||
};
|
||||
SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
|
||||
walkFn);
|
||||
}
|
||||
|
||||
LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder,
|
||||
OperationFolder &folder,
|
||||
Value value) {
|
||||
|
@ -347,6 +485,16 @@ void SCCPSolver::visitOperation(Operation *op) {
|
|||
if (op->isKnownTerminator())
|
||||
visitTerminatorOperation(op, operandConstants);
|
||||
|
||||
// Process call operations. The call visitor processes result values, so we
|
||||
// can exit afterwards.
|
||||
if (CallOpInterface call = dyn_cast<CallOpInterface>(op))
|
||||
return visitCallOperation(call);
|
||||
|
||||
// Process callable operations. These are specially handled region operations
|
||||
// that track dataflow via calls.
|
||||
if (isa<CallableOpInterface>(op))
|
||||
return visitCallableOperation(op);
|
||||
|
||||
// Process region holding operations. The region visitor processes result
|
||||
// values, so we can exit afterwards.
|
||||
if (op->getNumRegions())
|
||||
|
@ -399,6 +547,62 @@ void SCCPSolver::visitOperation(Operation *op) {
|
|||
}
|
||||
}
|
||||
|
||||
void SCCPSolver::visitCallableOperation(Operation *op) {
|
||||
// Mark the regions as executable.
|
||||
bool isTrackingLatticeState = callableLatticeState.count(op);
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
if (region.empty())
|
||||
continue;
|
||||
Block *entryBlock = ®ion.front();
|
||||
markBlockExecutable(entryBlock);
|
||||
|
||||
// If we aren't tracking lattice state for this callable, mark all of the
|
||||
// region arguments as overdefined.
|
||||
if (!isTrackingLatticeState)
|
||||
markAllOverdefined(entryBlock->getArguments());
|
||||
}
|
||||
|
||||
// TODO: Add support for non-symbol callables when necessary. If the callable
|
||||
// has non-call uses we would mark overdefined, otherwise allow for
|
||||
// propagating the return values out.
|
||||
markAllOverdefined(op, op->getResults());
|
||||
}
|
||||
|
||||
void SCCPSolver::visitCallOperation(CallOpInterface op) {
|
||||
ResultRange callResults = op.getOperation()->getResults();
|
||||
|
||||
// Resolve the callable operation for this call.
|
||||
Operation *callableOp = nullptr;
|
||||
if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>())
|
||||
callableOp = callableValue.getDefiningOp();
|
||||
else
|
||||
callableOp = callToSymbolCallable.lookup(op);
|
||||
|
||||
// The callable of this call can't be resolved, mark any results overdefined.
|
||||
if (!callableOp)
|
||||
return markAllOverdefined(op, callResults);
|
||||
|
||||
// If this callable is tracking state, merge the argument operands with the
|
||||
// arguments of the callable.
|
||||
auto callableLatticeIt = callableLatticeState.find(callableOp);
|
||||
if (callableLatticeIt == callableLatticeState.end())
|
||||
return markAllOverdefined(op, callResults);
|
||||
|
||||
OperandRange callOperands = op.getArgOperands();
|
||||
auto callableArgs = callableLatticeIt->second.getCallableArguments();
|
||||
for (auto it : llvm::zip(callOperands, callableArgs)) {
|
||||
BlockArgument callableArg = std::get<1>(it);
|
||||
if (latticeValues[callableArg].meet(latticeValues[std::get<0>(it)]))
|
||||
visitUsers(callableArg);
|
||||
}
|
||||
|
||||
// Merge in the lattice state for the callable results as well.
|
||||
auto callableResults = callableLatticeIt->second.getResultLatticeValues();
|
||||
for (auto it : llvm::zip(callResults, callableResults))
|
||||
meet(/*owner=*/op, /*to=*/latticeValues[std::get<0>(it)],
|
||||
/*from=*/std::get<1>(it));
|
||||
}
|
||||
|
||||
void SCCPSolver::visitRegionOperation(Operation *op,
|
||||
ArrayRef<Attribute> constantOperands) {
|
||||
// Check to see if we can reason about the internal control flow of this
|
||||
|
@ -509,9 +713,14 @@ void SCCPSolver::visitTerminatorOperation(
|
|||
Operation *op, ArrayRef<Attribute> constantOperands) {
|
||||
// If this operation has no successors, we treat it as an exiting terminator.
|
||||
if (op->getNumSuccessors() == 0) {
|
||||
// Check to see if the parent tracks region control flow.
|
||||
Region *parentRegion = op->getParentRegion();
|
||||
Operation *parentOp = parentRegion->getParentOp();
|
||||
|
||||
// Check to see if this is a terminator for a callable region.
|
||||
if (isa<CallableOpInterface>(parentOp))
|
||||
return visitCallableTerminatorOperation(parentOp, op);
|
||||
|
||||
// Otherwise, check to see if the parent tracks region control flow.
|
||||
auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp);
|
||||
if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
|
||||
return;
|
||||
|
@ -552,6 +761,42 @@ void SCCPSolver::visitTerminatorOperation(
|
|||
markEdgeExecutable(block, succ);
|
||||
}
|
||||
|
||||
void SCCPSolver::visitCallableTerminatorOperation(Operation *callable,
|
||||
Operation *terminator) {
|
||||
// If there are no exiting values, we have nothing to track.
|
||||
if (terminator->getNumOperands() == 0)
|
||||
return;
|
||||
|
||||
// If this callable isn't tracking any lattice state there is nothing to do.
|
||||
auto latticeIt = callableLatticeState.find(callable);
|
||||
if (latticeIt == callableLatticeState.end())
|
||||
return;
|
||||
assert(callable->getNumResults() == 0 && "expected symbol callable");
|
||||
|
||||
// If this terminator is not "return-like", conservatively mark all of the
|
||||
// call-site results as overdefined.
|
||||
auto callableResultLattices = latticeIt->second.getResultLatticeValues();
|
||||
if (!terminator->hasTrait<OpTrait::ReturnLike>()) {
|
||||
for (auto &it : callableResultLattices)
|
||||
it.markOverdefined();
|
||||
for (Operation *call : latticeIt->second.getSymbolCalls())
|
||||
markAllOverdefined(call, call->getResults());
|
||||
return;
|
||||
}
|
||||
|
||||
// Merge the terminator operands into the results.
|
||||
bool anyChanged = false;
|
||||
for (auto it : llvm::zip(terminator->getOperands(), callableResultLattices))
|
||||
anyChanged |= std::get<1>(it).meet(latticeValues[std::get<0>(it)]);
|
||||
if (!anyChanged)
|
||||
return;
|
||||
|
||||
// If any of the result lattices changed, update the callers.
|
||||
for (Operation *call : latticeIt->second.getSymbolCalls())
|
||||
for (auto it : llvm::zip(call->getResults(), callableResultLattices))
|
||||
meet(call, latticeValues[std::get<0>(it)], std::get<1>(it));
|
||||
}
|
||||
|
||||
void SCCPSolver::visitBlock(Block *block) {
|
||||
// If the block is not the entry block we need to compute the lattice state
|
||||
// for the block arguments. Entry block argument lattices are computed
|
||||
|
@ -663,7 +908,7 @@ void SCCP::runOnOperation() {
|
|||
Operation *op = getOperation();
|
||||
|
||||
// Solve for SCCP constraints within nested regions.
|
||||
SCCPSolver solver(op->getRegions());
|
||||
SCCPSolver solver(op);
|
||||
solver.solve();
|
||||
|
||||
// Cleanup any operations using the solver analysis.
|
||||
|
|
|
@ -0,0 +1,257 @@
|
|||
// RUN: mlir-opt -allow-unregistered-dialect %s -sccp -split-input-file | FileCheck %s -dump-input-on-failure
|
||||
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="module(sccp)" -split-input-file | FileCheck %s --check-prefix=NESTED -dump-input-on-failure
|
||||
|
||||
/// Check that a constant is properly propagated through the arguments and
|
||||
/// results of a private function.
|
||||
|
||||
// CHECK-LABEL: func @private(
|
||||
func @private(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
|
||||
// CHECK: %[[CST:.*]] = constant 1 : i32
|
||||
// CHECK: return %[[CST]] : i32
|
||||
|
||||
return %arg0 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @simple_private(
|
||||
func @simple_private() -> i32 {
|
||||
// CHECK: %[[CST:.*]] = constant 1 : i32
|
||||
// CHECK: return %[[CST]] : i32
|
||||
|
||||
%1 = constant 1 : i32
|
||||
%result = call @private(%1) : (i32) -> i32
|
||||
return %result : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Check that a constant is properly propagated through the arguments and
|
||||
/// results of a visible nested function.
|
||||
|
||||
// CHECK: func @nested(
|
||||
func @nested(%arg0 : i32) -> i32 attributes { sym_visibility = "nested" } {
|
||||
// CHECK: %[[CST:.*]] = constant 1 : i32
|
||||
// CHECK: return %[[CST]] : i32
|
||||
|
||||
return %arg0 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @simple_nested(
|
||||
func @simple_nested() -> i32 {
|
||||
// CHECK: %[[CST:.*]] = constant 1 : i32
|
||||
// CHECK: return %[[CST]] : i32
|
||||
|
||||
%1 = constant 1 : i32
|
||||
%result = call @nested(%1) : (i32) -> i32
|
||||
return %result : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Check that non-visible nested functions do not track arguments.
|
||||
module {
|
||||
// NESTED-LABEL: module @nested_module
|
||||
module @nested_module attributes { sym_visibility = "public" } {
|
||||
|
||||
// NESTED: func @nested(
|
||||
func @nested(%arg0 : i32) -> (i32, i32) attributes { sym_visibility = "nested" } {
|
||||
// NESTED: %[[CST:.*]] = constant 1 : i32
|
||||
// NESTED: return %[[CST]], %arg0 : i32, i32
|
||||
|
||||
%1 = constant 1 : i32
|
||||
return %1, %arg0 : i32, i32
|
||||
}
|
||||
|
||||
// NESTED: func @nested_not_all_uses_visible(
|
||||
func @nested_not_all_uses_visible() -> (i32, i32) {
|
||||
// NESTED: %[[CST:.*]] = constant 1 : i32
|
||||
// NESTED: %[[CALL:.*]]:2 = call @nested
|
||||
// NESTED: return %[[CST]], %[[CALL]]#1 : i32, i32
|
||||
|
||||
%1 = constant 1 : i32
|
||||
%result:2 = call @nested(%1) : (i32) -> (i32, i32)
|
||||
return %result#0, %result#1 : i32, i32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Check that public functions do not track arguments.
|
||||
|
||||
// CHECK-LABEL: func @public(
|
||||
func @public(%arg0 : i32) -> (i32, i32) attributes { sym_visibility = "public" } {
|
||||
%1 = constant 1 : i32
|
||||
return %1, %arg0 : i32, i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @simple_public(
|
||||
func @simple_public() -> (i32, i32) {
|
||||
// CHECK: %[[CST:.*]] = constant 1 : i32
|
||||
// CHECK: %[[CALL:.*]]:2 = call @public
|
||||
// CHECK: return %[[CST]], %[[CALL]]#1 : i32, i32
|
||||
|
||||
%1 = constant 1 : i32
|
||||
%result:2 = call @public(%1) : (i32) -> (i32, i32)
|
||||
return %result#0, %result#1 : i32, i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Check that functions with non-call users don't have arguments tracked.
|
||||
|
||||
func @callable(%arg0 : i32) -> (i32, i32) attributes { sym_visibility = "private" } {
|
||||
%1 = constant 1 : i32
|
||||
return %1, %arg0 : i32, i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @non_call_users(
|
||||
func @non_call_users() -> (i32, i32) {
|
||||
// CHECK: %[[CST:.*]] = constant 1 : i32
|
||||
// CHECK: %[[CALL:.*]]:2 = call @callable
|
||||
// CHECK: return %[[CST]], %[[CALL]]#1 : i32, i32
|
||||
|
||||
%1 = constant 1 : i32
|
||||
%result:2 = call @callable(%1) : (i32) -> (i32, i32)
|
||||
return %result#0, %result#1 : i32, i32
|
||||
}
|
||||
|
||||
"live.user"() {uses = [@callable]} : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
/// Check that return values are overdefined in the presence of an unknown terminator.
|
||||
|
||||
func @callable(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
|
||||
"unknown.return"(%arg0) : (i32) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @unknown_terminator(
|
||||
func @unknown_terminator() -> i32 {
|
||||
// CHECK: %[[CALL:.*]] = call @callable
|
||||
// CHECK: return %[[CALL]] : i32
|
||||
|
||||
%1 = constant 1 : i32
|
||||
%result = call @callable(%1) : (i32) -> i32
|
||||
return %result : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Check that return values are overdefined when the constant conflicts.
|
||||
|
||||
func @callable(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
|
||||
"unknown.return"(%arg0) : (i32) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @conflicting_constant(
|
||||
func @conflicting_constant() -> (i32, i32) {
|
||||
// CHECK: %[[CALL1:.*]] = call @callable
|
||||
// CHECK: %[[CALL2:.*]] = call @callable
|
||||
// CHECK: return %[[CALL1]], %[[CALL2]] : i32, i32
|
||||
|
||||
%1 = constant 1 : i32
|
||||
%2 = constant 2 : i32
|
||||
%result = call @callable(%1) : (i32) -> i32
|
||||
%result2 = call @callable(%2) : (i32) -> i32
|
||||
return %result, %result2 : i32, i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Check that return values are overdefined when the constant conflicts with a
|
||||
/// non-constant.
|
||||
|
||||
func @callable(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
|
||||
"unknown.return"(%arg0) : (i32) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @conflicting_constant(
|
||||
func @conflicting_constant(%arg0 : i32) -> (i32, i32) {
|
||||
// CHECK: %[[CALL1:.*]] = call @callable
|
||||
// CHECK: %[[CALL2:.*]] = call @callable
|
||||
// CHECK: return %[[CALL1]], %[[CALL2]] : i32, i32
|
||||
|
||||
%1 = constant 1 : i32
|
||||
%result = call @callable(%1) : (i32) -> i32
|
||||
%result2 = call @callable(%arg0) : (i32) -> i32
|
||||
return %result, %result2 : i32, i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Check a more complex interaction with calls and control flow.
|
||||
|
||||
// CHECK-LABEL: func @complex_inner_if(
|
||||
func @complex_inner_if(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
|
||||
// CHECK-DAG: %[[TRUE:.*]] = constant 1 : i1
|
||||
// CHECK-DAG: %[[CST:.*]] = constant 1 : i32
|
||||
// CHECK: cond_br %[[TRUE]], ^bb1
|
||||
|
||||
%cst_20 = constant 20 : i32
|
||||
%cond = cmpi "ult", %arg0, %cst_20 : i32
|
||||
cond_br %cond, ^bb1, ^bb2
|
||||
|
||||
^bb1:
|
||||
// CHECK: ^bb1:
|
||||
// CHECK: return %[[CST]] : i32
|
||||
|
||||
%cst_1 = constant 1 : i32
|
||||
return %cst_1 : i32
|
||||
|
||||
^bb2:
|
||||
%cst_1_2 = constant 1 : i32
|
||||
%arg_inc = addi %arg0, %cst_1_2 : i32
|
||||
return %arg_inc : i32
|
||||
}
|
||||
|
||||
func @complex_cond() -> i1
|
||||
|
||||
// CHECK-LABEL: func @complex_callee(
|
||||
func @complex_callee(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
|
||||
// CHECK: %[[CST:.*]] = constant 1 : i32
|
||||
|
||||
%loop_cond = call @complex_cond() : () -> i1
|
||||
cond_br %loop_cond, ^bb1, ^bb2
|
||||
|
||||
^bb1:
|
||||
// CHECK: ^bb1:
|
||||
// CHECK-NEXT: return %[[CST]] : i32
|
||||
return %arg0 : i32
|
||||
|
||||
^bb2:
|
||||
// CHECK: ^bb2:
|
||||
// CHECK: call @complex_inner_if(%[[CST]]) : (i32) -> i32
|
||||
// CHECK: call @complex_callee(%[[CST]]) : (i32) -> i32
|
||||
// CHECK: return %[[CST]] : i32
|
||||
|
||||
%updated_arg = call @complex_inner_if(%arg0) : (i32) -> i32
|
||||
%res = call @complex_callee(%updated_arg) : (i32) -> i32
|
||||
return %res : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @complex_caller(
|
||||
func @complex_caller(%arg0 : i32) -> i32 {
|
||||
// CHECK: %[[CST:.*]] = constant 1 : i32
|
||||
// CHECK: return %[[CST]] : i32
|
||||
|
||||
%1 = constant 1 : i32
|
||||
%result = call @complex_callee(%1) : (i32) -> i32
|
||||
return %result : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Check that non-symbol defining callables currently go to overdefined.
|
||||
|
||||
// CHECK-LABEL: func @non_symbol_defining_callable
|
||||
func @non_symbol_defining_callable() -> i32 {
|
||||
// CHECK: %[[RES:.*]] = call_indirect
|
||||
// CHECK: return %[[RES]] : i32
|
||||
|
||||
%fn = "test.functional_region_op"() ({
|
||||
%1 = constant 1 : i32
|
||||
"test.return"(%1) : (i32) -> ()
|
||||
}) : () -> (() -> i32)
|
||||
%res = call_indirect %fn() : () -> (i32)
|
||||
return %res : i32
|
||||
}
|
|
@ -1090,7 +1090,7 @@ def TestRecursiveRewriteOp : TEST_Op<"recursive_rewrite"> {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TestRegionBuilderOp : TEST_Op<"region_builder">;
|
||||
def TestReturnOp : TEST_Op<"return", [Terminator]>,
|
||||
def TestReturnOp : TEST_Op<"return", [ReturnLike, Terminator]>,
|
||||
Arguments<(ins Variadic<AnyType>)>;
|
||||
def TestCastOp : TEST_Op<"cast">,
|
||||
Arguments<(ins Variadic<AnyType>)>, Results<(outs AnyType)>;
|
||||
|
|
Loading…
Reference in New Issue