[mlir] Add Dead Code Analysis

This patch implements the analysis state classes needed for sparse data-flow analysis and implements a dead-code analysis using those states to determine liveness of blocks, control-flow edges, region predecessors, and function callsites.

Depends on D126751

Reviewed By: rriddle, phisiart

Differential Revision: https://reviews.llvm.org/D127064
This commit is contained in:
Mogball 2022-06-23 19:02:45 +00:00
parent 0586d1cac2
commit c095afcba6
13 changed files with 1332 additions and 23 deletions

View File

@ -0,0 +1,66 @@
//===- ConstantPropagationAnalysis.h - Constant propagation analysis ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements constant propagation analysis. In this file are defined
// the lattice value class that represents constant values in the program and
// a sparse constant propagation analysis that uses operation folders to
// speculate about constant values in the program.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_DATAFLOW_CONSTANTPROPAGATIONANALYSIS_H
#define MLIR_ANALYSIS_DATAFLOW_CONSTANTPROPAGATIONANALYSIS_H
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
namespace mlir {
namespace dataflow {
//===----------------------------------------------------------------------===//
// ConstantValue
//===----------------------------------------------------------------------===//
/// This lattice value represents a known constant value of a lattice.
class ConstantValue {
public:
/// Construct a constant value with a known constant.
ConstantValue(Attribute knownValue = {}, Dialect *dialect = nullptr)
: constant(knownValue), dialect(dialect) {}
/// Get the constant value. Returns null if no value was determined.
Attribute getConstantValue() const { return constant; }
/// Get the dialect instance that can be used to materialize the constant.
Dialect *getConstantDialect() const { return dialect; }
/// Compare the constant values.
bool operator==(const ConstantValue &rhs) const {
return constant == rhs.constant;
}
/// The union with another constant value is null if they are different, and
/// the same if they are the same.
static ConstantValue join(const ConstantValue &lhs,
const ConstantValue &rhs) {
return lhs == rhs ? lhs : ConstantValue();
}
/// Print the constant value.
void print(raw_ostream &os) const;
private:
/// The constant value.
Attribute constant;
/// An dialect instance that can be used to materialize the constant.
Dialect *dialect;
};
} // end namespace dataflow
} // end namespace mlir
#endif // MLIR_ANALYSIS_DATAFLOW_CONSTANTPROPAGATIONANALYSIS_H

View File

@ -0,0 +1,233 @@
//===- DeadCodeAnalysis.h - Dead code analysis ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements dead code analysis using the data-flow analysis
// framework. This analysis uses the results of constant propagation to
// determine live blocks, control-flow edges, and control-flow predecessors.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_DATAFLOW_DEADCODEANALYSIS_H
#define MLIR_ANALYSIS_DATAFLOW_DEADCODEANALYSIS_H
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/SmallPtrSet.h"
namespace mlir {
class CallOpInterface;
class CallableOpInterface;
class BranchOpInterface;
class RegionBranchOpInterface;
namespace dataflow {
//===----------------------------------------------------------------------===//
// Executable
//===----------------------------------------------------------------------===//
/// This is a simple analysis state that represents whether the associated
/// program point (either a block or a control-flow edge) is live.
class Executable : public AnalysisState {
public:
using AnalysisState::AnalysisState;
/// The state is initialized by default.
bool isUninitialized() const override { return false; }
/// The state is always initialized.
ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
/// Set the state of the program point to live.
ChangeResult setToLive();
/// Get whether the program point is live.
bool isLive() const { return live; }
/// Print the liveness.
void print(raw_ostream &os) const override;
/// When the state of the program point is changed to live, re-invoke
/// subscribed analyses on the operations in the block and on the block
/// itself.
void onUpdate(DataFlowSolver *solver) const override;
/// Subscribe an analysis to changes to the liveness.
void blockContentSubscribe(DataFlowAnalysis *analysis) {
subscribers.insert(analysis);
}
private:
/// Whether the program point is live. Optimistically assume that the program
/// point is dead.
bool live = false;
/// A set of analyses that should be updated when this state changes.
SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
SmallPtrSet<DataFlowAnalysis *, 4>>
subscribers;
};
//===----------------------------------------------------------------------===//
// PredecessorState
//===----------------------------------------------------------------------===//
/// This analysis state represents a set of live control-flow "predecessors" of
/// a program point (either an operation or a block), which are the last
/// operations along all execution paths that pass through this point.
///
/// For example, in dead-code analysis, an operation with region control-flow
/// can be the predecessor of a region's entry block or itself, the exiting
/// terminator of a region can be the predecessor of the parent operation or
/// another region's entry block, the callsite of a callable operation can be
/// the predecessor to its entry block, and the exiting terminator or a callable
/// operation can be the predecessor of the call operation.
///
/// The state can indicate that it is underdefined, meaning that not all live
/// control-flow predecessors can be known.
class PredecessorState : public AnalysisState {
public:
using AnalysisState::AnalysisState;
/// The state is initialized by default.
bool isUninitialized() const override { return false; }
/// The state is always initialized.
ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
/// Print the known predecessors.
void print(raw_ostream &os) const override;
/// Returns true if all predecessors are known.
bool allPredecessorsKnown() const { return allKnown; }
/// Indicate that there are potentially unknown predecessors.
ChangeResult setHasUnknownPredecessors() {
return std::exchange(allKnown, false) ? ChangeResult::Change
: ChangeResult::NoChange;
}
/// Get the known predecessors.
ArrayRef<Operation *> getKnownPredecessors() const {
return knownPredecessors.getArrayRef();
}
/// Add a known predecessor.
ChangeResult join(Operation *predecessor) {
return knownPredecessors.insert(predecessor) ? ChangeResult::Change
: ChangeResult::NoChange;
}
private:
/// Whether all predecessors are known. Optimistically assume that we know
/// all predecessors.
bool allKnown = true;
/// The known control-flow predecessors of this program point.
SetVector<Operation *, SmallVector<Operation *, 4>,
SmallPtrSet<Operation *, 4>>
knownPredecessors;
};
//===----------------------------------------------------------------------===//
// CFGEdge
//===----------------------------------------------------------------------===//
/// This program point represents a control-flow edge between a block and one
/// of its successors.
class CFGEdge
: public GenericProgramPointBase<CFGEdge, std::pair<Block *, Block *>> {
public:
using Base::Base;
/// Get the block from which the edge originates.
Block *getFrom() const { return getValue().first; }
/// Get the target block.
Block *getTo() const { return getValue().second; }
/// Print the blocks between the control-flow edge.
void print(raw_ostream &os) const override;
/// Get a fused location of both blocks.
Location getLoc() const override;
};
//===----------------------------------------------------------------------===//
// DeadCodeAnalysis
//===----------------------------------------------------------------------===//
/// Dead code analysis analyzes control-flow, as understood by
/// `RegionBranchOpInterface` and `BranchOpInterface`, and the callgraph, as
/// understood by `CallableOpInterface` and `CallOpInterface`.
///
/// This analysis uses known constant values of operands to determine the
/// liveness of each block and each edge between a block and its predecessors.
/// For region control-flow, this analysis determines the predecessor operations
/// for region entry blocks and region control-flow operations. For the
/// callgraph, this analysis determines the callsites and live returns of every
/// function.
class DeadCodeAnalysis : public DataFlowAnalysis {
public:
explicit DeadCodeAnalysis(DataFlowSolver &solver);
/// Initialize the analysis by visiting every operation with potential
/// control-flow semantics.
LogicalResult initialize(Operation *top) override;
/// Visit an operation with control-flow semantics and deduce which of its
/// successors are live.
LogicalResult visit(ProgramPoint point) override;
private:
/// Find and mark symbol callables with potentially unknown callsites as
/// having overdefined predecessors. `top` is the top-level operation that the
/// analysis is operating on.
void initializeSymbolCallables(Operation *top);
/// Recursively Initialize the analysis on nested regions.
LogicalResult initializeRecursively(Operation *op);
/// Visit the given call operation and compute any necessary lattice state.
void visitCallOperation(CallOpInterface call);
/// Visit the given branch operation with successors and try to determine
/// which are live from the current block.
void visitBranchOperation(BranchOpInterface branch);
/// Visit the given region branch operation, which defines regions, and
/// compute any necessary lattice state. This also resolves the lattice state
/// of both the operation results and any nested regions.
void visitRegionBranchOperation(RegionBranchOpInterface branch);
/// Visit the given terminator operation that exits a region under an
/// operation with control-flow semantics. These are terminators with no CFG
/// successors.
void visitRegionTerminator(Operation *op, RegionBranchOpInterface branch);
/// Visit the given terminator operation that exits a callable region. These
/// are terminators with no CFG successors.
void visitCallableTerminator(Operation *op, CallableOpInterface callable);
/// Mark the edge between `from` and `to` as executable.
void markEdgeLive(Block *from, Block *to);
/// Mark the entry blocks of the operation as executable.
void markEntryBlocksLive(Operation *op);
/// Get the constant values of the operands of the operation. Returns none if
/// any of the operand lattices are uninitialized.
Optional<SmallVector<Attribute>> getOperandValues(Operation *op);
/// A symbol table used for O(1) symbol lookups during simplification.
SymbolTableCollection symbolTable;
};
} // end namespace dataflow
} // end namespace mlir
#endif // MLIR_ANALYSIS_DATAFLOW_DEADCODEANALYSIS_H

View File

@ -0,0 +1,185 @@
//===- SparseAnalysis.h - Sparse data-flow analysis -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements sparse data-flow analysis using the data-flow analysis
// framework. The analysis is forward and conditional and uses the results of
// dead code analysis to prune dead code during the analysis.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
#define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
#include "mlir/Analysis/DataFlowFramework.h"
#include "llvm/ADT/SmallPtrSet.h"
namespace mlir {
namespace dataflow {
//===----------------------------------------------------------------------===//
// AbstractSparseLattice
//===----------------------------------------------------------------------===//
/// This class represents an abstract lattice. A lattice contains information
/// about an SSA value and is what's propagated across the IR by sparse
/// data-flow analysis.
class AbstractSparseLattice : public AnalysisState {
public:
/// Lattices can only be created for values.
AbstractSparseLattice(Value value) : AnalysisState(value) {}
/// Join the information contained in 'rhs' into this lattice. Returns
/// if the value of the lattice changed.
virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0;
/// Returns true if the lattice element is at fixpoint and further calls to
/// `join` will not update the value of the element.
virtual bool isAtFixpoint() const = 0;
/// Mark the lattice element as having reached a pessimistic fixpoint. This
/// means that the lattice may potentially have conflicting value states, and
/// only the most conservative value should be relied on.
virtual ChangeResult markPessimisticFixpoint() = 0;
/// When the lattice gets updated, propagate an update to users of the value
/// using its use-def chain to subscribed analyses.
void onUpdate(DataFlowSolver *solver) const override;
/// Subscribe an analysis to updates of the lattice. When the lattice changes,
/// subscribed analyses are re-invoked on all users of the value. This is
/// more efficient than relying on the dependency map.
void useDefSubscribe(DataFlowAnalysis *analysis) {
useDefSubscribers.insert(analysis);
}
private:
/// A set of analyses that should be updated when this lattice changes.
SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
SmallPtrSet<DataFlowAnalysis *, 4>>
useDefSubscribers;
};
//===----------------------------------------------------------------------===//
// Lattice
//===----------------------------------------------------------------------===//
/// This class represents a lattice holding a specific value of type `ValueT`.
/// Lattice values (`ValueT`) are required to adhere to the following:
///
/// * static ValueT join(const ValueT &lhs, const ValueT &rhs);
/// - This method conservatively joins the information held by `lhs`
/// and `rhs` into a new value. This method is required to be monotonic.
/// * bool operator==(const ValueT &rhs) const;
///
template <typename ValueT>
class Lattice : public AbstractSparseLattice {
public:
using AbstractSparseLattice::AbstractSparseLattice;
/// Get a lattice element with a known value.
Lattice(const ValueT &knownValue = ValueT())
: AbstractSparseLattice(Value()), knownValue(knownValue) {}
/// Return the value held by this lattice. This requires that the value is
/// initialized.
ValueT &getValue() {
assert(!isUninitialized() && "expected known lattice element");
return *optimisticValue;
}
const ValueT &getValue() const {
return const_cast<Lattice<ValueT> *>(this)->getValue();
}
/// Returns true if the value of this lattice hasn't yet been initialized.
bool isUninitialized() const override { return !optimisticValue.hasValue(); }
/// Force the initialization of the element by setting it to its pessimistic
/// fixpoint.
ChangeResult defaultInitialize() override {
return markPessimisticFixpoint();
}
/// Returns true if the lattice has reached a fixpoint. A fixpoint is when
/// the information optimistically assumed to be true is the same as the
/// information known to be true.
bool isAtFixpoint() const override { return optimisticValue == knownValue; }
/// Join the information contained in the 'rhs' lattice into this
/// lattice. Returns if the state of the current lattice changed.
ChangeResult join(const AbstractSparseLattice &rhs) override {
const Lattice<ValueT> &rhsLattice =
static_cast<const Lattice<ValueT> &>(rhs);
// If we are at a fixpoint, or rhs is uninitialized, there is nothing to do.
if (isAtFixpoint() || rhsLattice.isUninitialized())
return ChangeResult::NoChange;
// Join the rhs value into this lattice.
return join(rhsLattice.getValue());
}
/// Join the information contained in the 'rhs' value into this
/// lattice. Returns if the state of the current lattice changed.
ChangeResult join(const ValueT &rhs) {
// If the current lattice is uninitialized, copy the rhs value.
if (isUninitialized()) {
optimisticValue = rhs;
return ChangeResult::Change;
}
// Otherwise, join rhs with the current optimistic value.
ValueT newValue = ValueT::join(*optimisticValue, rhs);
assert(ValueT::join(newValue, *optimisticValue) == newValue &&
"expected `join` to be monotonic");
assert(ValueT::join(newValue, rhs) == newValue &&
"expected `join` to be monotonic");
// Update the current optimistic value if something changed.
if (newValue == optimisticValue)
return ChangeResult::NoChange;
optimisticValue = newValue;
return ChangeResult::Change;
}
/// Mark the lattice element as having reached a pessimistic fixpoint. This
/// means that the lattice may potentially have conflicting value states,
/// and only the conservatively known value state should be relied on.
ChangeResult markPessimisticFixpoint() override {
if (isAtFixpoint())
return ChangeResult::NoChange;
// For this fixed point, we take whatever we knew to be true and set that
// to our optimistic value.
optimisticValue = knownValue;
return ChangeResult::Change;
}
/// Print the lattice element.
void print(raw_ostream &os) const override {
os << "[";
knownValue.print(os);
os << ", ";
if (optimisticValue)
optimisticValue->print(os);
else
os << "<NULL>";
os << "]";
}
private:
/// The value that is conservatively known to be true.
ValueT knownValue;
/// The currently computed value that is optimistically assumed to be true,
/// or None if the lattice element is uninitialized.
Optional<ValueT> optimisticValue;
};
} // end namespace dataflow
} // end namespace mlir
#endif // MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H

View File

@ -22,34 +22,17 @@
#ifndef MLIR_ANALYSIS_DATAFLOWANALYSIS_H
#define MLIR_ANALYSIS_DATAFLOWANALYSIS_H
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/Support/Allocator.h"
/// TODO: Remove this file when SCCP and integer range analysis have been ported
/// to the new framework.
namespace mlir {
//===----------------------------------------------------------------------===//
// ChangeResult
//===----------------------------------------------------------------------===//
/// A result type used to indicate if a change happened. Boolean operations on
/// ChangeResult behave as though `Change` is truthy.
enum class ChangeResult {
NoChange,
Change,
};
inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) {
return lhs == ChangeResult::Change ? lhs : rhs;
}
inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) {
lhs = lhs | rhs;
return lhs;
}
inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
return lhs == ChangeResult::NoChange ? lhs : rhs;
}
//===----------------------------------------------------------------------===//
// AbstractLatticeElement
//===----------------------------------------------------------------------===//

View File

@ -16,7 +16,6 @@
#ifndef MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
#define MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/StorageUniquer.h"
#include "llvm/ADT/SetVector.h"
@ -25,6 +24,27 @@
namespace mlir {
//===----------------------------------------------------------------------===//
// ChangeResult
//===----------------------------------------------------------------------===//
/// A result type used to indicate if a change happened. Boolean operations on
/// ChangeResult behave as though `Change` is truthy.
enum class ChangeResult {
NoChange,
Change,
};
inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) {
return lhs == ChangeResult::Change ? lhs : rhs;
}
inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) {
lhs = lhs | rhs;
return lhs;
}
inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
return lhs == ChangeResult::NoChange ? lhs : rhs;
}
/// Forward declare the analysis state class.
class AnalysisState;
@ -137,6 +157,12 @@ struct ProgramPoint
using ParentTy::PointerUnion;
/// Allow implicit conversion from the parent type.
ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
/// Allow implicit conversions from operation wrappers.
/// TODO: For Windows only. Find a better solution.
template <typename OpT, typename = typename std::enable_if_t<
std::is_convertible<OpT, Operation *>::value &&
!std::is_same<OpT, Operation *>::value>>
ProgramPoint(OpT op) : ParentTy(op) {}
/// Print the program point.
void print(raw_ostream &os) const;
@ -180,7 +206,7 @@ public:
/// does not exist.
template <typename StateT, typename PointT>
const StateT *lookupState(PointT point) const {
auto it = analysisStates.find({point, TypeID::get<StateT>()});
auto it = analysisStates.find({ProgramPoint(point), TypeID::get<StateT>()});
if (it == analysisStates.end())
return nullptr;
return static_cast<const StateT *>(it->second.get());

View File

@ -9,6 +9,10 @@ set(LLVM_OPTIONAL_SOURCES
SliceAnalysis.cpp
AliasAnalysis/LocalAliasAnalysis.cpp
DataFlow/ConstantPropagationAnalysis.cpp
DataFlow/DeadCodeAnalysis.cpp
DataFlow/SparseAnalysis.cpp
)
add_mlir_library(MLIRAnalysis
@ -24,6 +28,10 @@ add_mlir_library(MLIRAnalysis
AliasAnalysis/LocalAliasAnalysis.cpp
DataFlow/ConstantPropagationAnalysis.cpp
DataFlow/DeadCodeAnalysis.cpp
DataFlow/SparseAnalysis.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis

View File

@ -0,0 +1,22 @@
//===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
using namespace mlir;
using namespace mlir::dataflow;
//===----------------------------------------------------------------------===//
// ConstantValue
//===----------------------------------------------------------------------===//
void ConstantValue::print(raw_ostream &os) const {
if (constant)
return constant.print(os);
os << "<NO VALUE>";
}

View File

@ -0,0 +1,394 @@
//===- DeadCodeAnalysis.cpp - Dead code analysis --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
using namespace mlir;
using namespace mlir::dataflow;
//===----------------------------------------------------------------------===//
// Executable
//===----------------------------------------------------------------------===//
ChangeResult Executable::setToLive() {
if (live)
return ChangeResult::NoChange;
live = true;
return ChangeResult::Change;
}
void Executable::print(raw_ostream &os) const {
os << (live ? "live" : "dead");
}
void Executable::onUpdate(DataFlowSolver *solver) const {
if (auto *block = point.dyn_cast<Block *>()) {
// Re-invoke the analyses on the block itself.
for (DataFlowAnalysis *analysis : subscribers)
solver->enqueue({block, analysis});
// Re-invoke the analyses on all operations in the block.
for (DataFlowAnalysis *analysis : subscribers)
for (Operation &op : *block)
solver->enqueue({&op, analysis});
} else if (auto *programPoint = point.dyn_cast<GenericProgramPoint *>()) {
// Re-invoke the analysis on the successor block.
if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
for (DataFlowAnalysis *analysis : subscribers)
solver->enqueue({edge->getTo(), analysis});
}
}
}
//===----------------------------------------------------------------------===//
// PredecessorState
//===----------------------------------------------------------------------===//
void PredecessorState::print(raw_ostream &os) const {
if (allPredecessorsKnown())
os << "(all) ";
os << "predecessors:\n";
for (Operation *op : getKnownPredecessors())
os << " " << *op << "\n";
}
//===----------------------------------------------------------------------===//
// CFGEdge
//===----------------------------------------------------------------------===//
Location CFGEdge::getLoc() const {
return FusedLoc::get(
getFrom()->getParent()->getContext(),
{getFrom()->getParent()->getLoc(), getTo()->getParent()->getLoc()});
}
void CFGEdge::print(raw_ostream &os) const {
getFrom()->print(os);
os << "\n -> \n";
getTo()->print(os);
}
//===----------------------------------------------------------------------===//
// DeadCodeAnalysis
//===----------------------------------------------------------------------===//
DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
: DataFlowAnalysis(solver) {
registerPointKind<CFGEdge>();
}
LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
// Mark the top-level blocks as executable.
for (Region &region : top->getRegions()) {
if (region.empty())
continue;
auto *state = getOrCreate<Executable>(&region.front());
propagateIfChanged(state, state->setToLive());
}
// Mark as overdefined the predecessors of symbol callables with potentially
// unknown predecessors.
initializeSymbolCallables(top);
return initializeRecursively(top);
}
void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
Region &symbolTableRegion = symTable->getRegion(0);
Block *symbolTableBlock = &symbolTableRegion.front();
bool foundSymbolCallable = false;
for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
Region *callableRegion = callable.getCallableRegion();
if (!callableRegion)
continue;
auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
if (!symbol)
continue;
// Public symbol callables or those for which we can't see all uses have
// potentially unknown callsites.
if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
auto *state = getOrCreate<PredecessorState>(callable);
propagateIfChanged(state, state->setHasUnknownPredecessors());
}
foundSymbolCallable = true;
}
// Exit early if no eligible symbol callables were found in the table.
if (!foundSymbolCallable)
return;
// Walk the symbol table to check for non-call uses of symbols.
Optional<SymbolTable::UseRange> uses =
SymbolTable::getSymbolUses(&symbolTableRegion);
if (!uses) {
// If we couldn't gather the symbol uses, conservatively assume that
// we can't track information for any nested symbols.
return top->walk([&](CallableOpInterface callable) {
auto *state = getOrCreate<PredecessorState>(callable);
propagateIfChanged(state, state->setHasUnknownPredecessors());
});
}
for (const SymbolTable::SymbolUse &use : *uses) {
if (isa<CallOpInterface>(use.getUser()))
continue;
// If a callable symbol has a non-call use, then we can't be guaranteed to
// know all callsites.
Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef());
auto *state = getOrCreate<PredecessorState>(symbol);
propagateIfChanged(state, state->setHasUnknownPredecessors());
}
};
SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
walkFn);
}
LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
// Initialize the analysis by visiting every op with control-flow semantics.
if (op->getNumRegions() || op->getNumSuccessors() ||
op->hasTrait<OpTrait::IsTerminator>() || isa<CallOpInterface>(op)) {
// When the liveness of the parent block changes, make sure to re-invoke the
// analysis on the op.
if (op->getBlock())
getOrCreate<Executable>(op->getBlock())->blockContentSubscribe(this);
// Visit the op.
if (failed(visit(op)))
return failure();
}
// Recurse on nested operations.
for (Region &region : op->getRegions())
for (Operation &op : region.getOps())
if (failed(initializeRecursively(&op)))
return failure();
return success();
}
void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
auto *state = getOrCreate<Executable>(to);
propagateIfChanged(state, state->setToLive());
auto *edgeState = getOrCreate<Executable>(getProgramPoint<CFGEdge>(from, to));
propagateIfChanged(edgeState, edgeState->setToLive());
}
void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
for (Region &region : op->getRegions()) {
if (region.empty())
continue;
auto *state = getOrCreate<Executable>(&region.front());
propagateIfChanged(state, state->setToLive());
}
}
LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
if (point.is<Block *>())
return success();
auto *op = point.dyn_cast<Operation *>();
if (!op)
return emitError(point.getLoc(), "unknown program point kind");
// If the parent block is not executable, there is nothing to do.
if (!getOrCreate<Executable>(op->getBlock())->isLive())
return success();
// We have a live call op. Add this as a live predecessor of the callee.
if (auto call = dyn_cast<CallOpInterface>(op))
visitCallOperation(call);
// Visit the regions.
if (op->getNumRegions()) {
// Check if we can reason about the region control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
visitRegionBranchOperation(branch);
// Check if this is a callable operation.
} else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
const auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
// If the callsites could not be resolved or are known to be non-empty,
// mark the callable as executable.
if (!callsites->allPredecessorsKnown() ||
!callsites->getKnownPredecessors().empty())
markEntryBlocksLive(callable);
// Otherwise, conservatively mark all entry blocks as executable.
} else {
markEntryBlocksLive(op);
}
}
if (op->hasTrait<OpTrait::IsTerminator>() && !op->getNumSuccessors()) {
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
// Visit the exiting terminator of a region.
visitRegionTerminator(op, branch);
} else if (auto callable =
dyn_cast<CallableOpInterface>(op->getParentOp())) {
// Visit the exiting terminator of a callable.
visitCallableTerminator(op, callable);
}
}
// Visit the successors.
if (op->getNumSuccessors()) {
// Check if we can reason about the control-flow.
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
visitBranchOperation(branch);
// Otherwise, conservatively mark all successors as exectuable.
} else {
for (Block *successor : op->getSuccessors())
markEdgeLive(op->getBlock(), successor);
}
}
return success();
}
void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
Operation *callableOp = nullptr;
if (Value callableValue = call.getCallableForCallee().dyn_cast<Value>())
callableOp = callableValue.getDefiningOp();
else
callableOp = call.resolveCallable(&symbolTable);
// A call to a externally-defined callable has unknown predecessors.
const auto isExternalCallable = [](Operation *op) {
if (auto callable = dyn_cast<CallableOpInterface>(op))
return !callable.getCallableRegion();
return false;
};
// TODO: Add support for non-symbol callables when necessary. If the
// callable has non-call uses we would mark as having reached pessimistic
// fixpoint, otherwise allow for propagating the return values out.
if (isa_and_nonnull<SymbolOpInterface>(callableOp) &&
!isExternalCallable(callableOp)) {
// Add the live callsite.
auto *callsites = getOrCreate<PredecessorState>(callableOp);
propagateIfChanged(callsites, callsites->join(call));
} else {
// Mark this call op's predecessors as overdefined.
auto *predecessors = getOrCreate<PredecessorState>(call);
propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
}
}
/// Get the constant values of the operands of an operation. If any of the
/// constant value lattices are uninitialized, return none to indicate the
/// analysis should bail out.
static Optional<SmallVector<Attribute>> getOperandValuesImpl(
Operation *op,
function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
SmallVector<Attribute> operands;
operands.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
const Lattice<ConstantValue> *cv = getLattice(operand);
// If any of the operands' values are uninitialized, bail out.
if (cv->isUninitialized())
return {};
operands.push_back(cv->getValue().getConstantValue());
}
return operands;
}
Optional<SmallVector<Attribute>>
DeadCodeAnalysis::getOperandValues(Operation *op) {
return getOperandValuesImpl(op, [&](Value value) {
auto *lattice = getOrCreate<Lattice<ConstantValue>>(value);
lattice->useDefSubscribe(this);
return lattice;
});
}
void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
// Try to deduce a single successor for the branch.
Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
if (!operands)
return;
if (Block *successor = branch.getSuccessorForOperands(*operands)) {
markEdgeLive(branch->getBlock(), successor);
} else {
// Otherwise, mark all successors as executable and outgoing edges.
for (Block *successor : branch->getSuccessors())
markEdgeLive(branch->getBlock(), successor);
}
}
void DeadCodeAnalysis::visitRegionBranchOperation(
RegionBranchOpInterface branch) {
// Try to deduce which regions are executable.
Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
if (!operands)
return;
SmallVector<RegionSuccessor> successors;
branch.getSuccessorRegions(/*index=*/{}, *operands, successors);
for (const RegionSuccessor &successor : successors) {
// Mark the entry block as executable.
Region *region = successor.getSuccessor();
assert(region && "expected a region successor");
auto *state = getOrCreate<Executable>(&region->front());
propagateIfChanged(state, state->setToLive());
// Add the parent op as a predecessor.
auto *predecessors = getOrCreate<PredecessorState>(&region->front());
propagateIfChanged(predecessors, predecessors->join(branch));
}
}
void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
RegionBranchOpInterface branch) {
Optional<SmallVector<Attribute>> operands = getOperandValues(op);
if (!operands)
return;
SmallVector<RegionSuccessor> successors;
branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
*operands, successors);
// Mark successor region entry blocks as executable and add this op to the
// list of predecessors.
for (const RegionSuccessor &successor : successors) {
PredecessorState *predecessors;
if (Region *region = successor.getSuccessor()) {
auto *state = getOrCreate<Executable>(&region->front());
propagateIfChanged(state, state->setToLive());
predecessors = getOrCreate<PredecessorState>(&region->front());
} else {
// Add this terminator as a predecessor to the parent op.
predecessors = getOrCreate<PredecessorState>(branch);
}
propagateIfChanged(predecessors, predecessors->join(op));
}
}
void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
CallableOpInterface callable) {
// If there are no exiting values, we have nothing to do.
if (op->getNumOperands() == 0)
return;
// Add as predecessors to all callsites this return op.
auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
bool canResolve = op->hasTrait<OpTrait::ReturnLike>();
for (Operation *predecessor : callsites->getKnownPredecessors()) {
assert(isa<CallOpInterface>(predecessor));
auto *predecessors = getOrCreate<PredecessorState>(predecessor);
if (canResolve) {
propagateIfChanged(predecessors, predecessors->join(op));
} else {
// If the terminator is not a return-like, then conservatively assume we
// can't resolve the predecessor.
propagateIfChanged(predecessors,
predecessors->setHasUnknownPredecessors());
}
}
}

View File

@ -0,0 +1,23 @@
//===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
using namespace mlir;
using namespace mlir::dataflow;
//===----------------------------------------------------------------------===//
// AbstractSparseLattice
//===----------------------------------------------------------------------===//
void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
// Push all users of the value to the queue.
for (Operation *user : point.get<Value>().getUsers())
for (DataFlowAnalysis *analysis : useDefSubscribers)
solver->enqueue({user, analysis});
}

View File

@ -0,0 +1,248 @@
// RUN: mlir-opt -test-dead-code-analysis 2>&1 %s | FileCheck %s
// CHECK: test_cfg:
// CHECK: region #0
// CHECK: ^bb0 = live
// CHECK: ^bb1 = live
// CHECK: from ^bb1 = live
// CHECK: from ^bb0 = live
// CHECK: ^bb2 = live
// CHECK: from ^bb1 = live
func.func @test_cfg(%cond: i1) -> ()
attributes {tag = "test_cfg"} {
cf.br ^bb1
^bb1:
cf.cond_br %cond, ^bb1, ^bb2
^bb2:
return
}
func.func @test_region_control_flow(%cond: i1, %arg0: i64, %arg1: i64) -> () {
// CHECK: test_if:
// CHECK: region #0
// CHECK: region_preds: (all) predecessors:
// CHECK: scf.if
// CHECK: region #1
// CHECK: region_preds: (all) predecessors:
// CHECK: scf.if
// CHECK: op_preds: (all) predecessors:
// CHECK: scf.yield {then}
// CHECK: scf.yield {else}
scf.if %cond {
scf.yield {then}
} else {
scf.yield {else}
} {tag = "test_if"}
// test_while:
// region #0
// region_preds: (all) predecessors:
// scf.while
// scf.yield
// region #1
// region_preds: (all) predecessors:
// scf.condition
// op_preds: (all) predecessors:
// scf.condition
%c2_i64 = arith.constant 2 : i64
%0:2 = scf.while (%arg2 = %arg0) : (i64) -> (i64, i64) {
%1 = arith.cmpi slt, %arg2, %arg1 : i64
scf.condition(%1) %arg2, %arg2 : i64, i64
} do {
^bb0(%arg2: i64, %arg3: i64):
%1 = arith.muli %arg3, %c2_i64 : i64
scf.yield %1 : i64
} attributes {tag = "test_while"}
return
}
// CHECK: foo:
// CHECK: region #0
// CHECK: ^bb0 = live
// CHECK: op_preds: (all) predecessors:
// CHECK: func.call @foo(%{{.*}}) {tag = "a"}
// CHECK: func.call @foo(%{{.*}}) {tag = "b"}
func.func private @foo(%arg0: i32) -> i32
attributes {tag = "foo"} {
return {a} %arg0 : i32
}
// CHECK: bar:
// CHECK: region #0
// CHECK: ^bb0 = live
// CHECK: op_preds: predecessors:
// CHECK: func.call @bar(%{{.*}}) {tag = "c"}
func.func @bar(%cond: i1) -> i32
attributes {tag = "bar"} {
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
%c0 = arith.constant 0 : i32
return {b} %c0 : i32
^bb2:
%c1 = arith.constant 1 : i32
return {c} %c1 : i32
}
// CHECK: baz
// CHECK: op_preds: (all) predecessors:
func.func private @baz(i32) -> i32 attributes {tag = "baz"}
func.func @test_callgraph(%cond: i1, %arg0: i32) -> i32 {
// CHECK: a:
// CHECK: op_preds: (all) predecessors:
// CHECK: func.return {a}
%0 = func.call @foo(%arg0) {tag = "a"} : (i32) -> i32
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
// CHECK: b:
// CHECK: op_preds: (all) predecessors:
// CHECK: func.return {a}
%1 = func.call @foo(%arg0) {tag = "b"} : (i32) -> i32
return %1 : i32
^bb2:
// CHECK: c:
// CHECK: op_preds: (all) predecessors:
// CHECK: func.return {b}
// CHECK: func.return {c}
%2 = func.call @bar(%cond) {tag = "c"} : (i1) -> i32
// CHECK: d:
// CHECK: op_preds: predecessors:
%3 = func.call @baz(%arg0) {tag = "d"} : (i32) -> i32
return %2 : i32
}
// CHECK: test_unknown_branch:
// CHECK: region #0
// CHECK: ^bb0 = live
// CHECK: ^bb1 = live
// CHECK: from ^bb0 = live
// CHECK: ^bb2 = live
// CHECK: from ^bb0 = live
func.func @test_unknown_branch() -> ()
attributes {tag = "test_unknown_branch"} {
"test.unknown_br"() [^bb1, ^bb2] : () -> ()
^bb1:
return
^bb2:
return
}
// CHECK: test_unknown_region:
// CHECK: region #0
// CHECK: ^bb0 = live
// CHECK: region #1
// CHECK: ^bb0 = live
func.func @test_unknown_region() -> () {
"test.unknown_region_br"() ({
^bb0:
"test.unknown_region_end"() : () -> ()
}, {
^bb0:
"test.unknown_region_end"() : () -> ()
}) {tag = "test_unknown_region"} : () -> ()
return
}
// CHECK: test_known_dead_block:
// CHECK: region #0
// CHECK: ^bb0 = live
// CHECK: ^bb1 = live
// CHECK: ^bb2 = dead
func.func @test_known_dead_block() -> ()
attributes {tag = "test_known_dead_block"} {
%true = arith.constant true
cf.cond_br %true, ^bb1, ^bb2
^bb1:
return
^bb2:
return
}
// CHECK: test_known_dead_edge:
// CHECK: ^bb2 = live
// CHECK: from ^bb1 = dead
// CHECK: from ^bb0 = live
func.func @test_known_dead_edge(%arg0: i1) -> ()
attributes {tag = "test_known_dead_edge"} {
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
%true = arith.constant true
cf.cond_br %true, ^bb3, ^bb2
^bb2:
return
^bb3:
return
}
func.func @test_known_region_predecessors() -> () {
%false = arith.constant false
// CHECK: test_known_if:
// CHECK: region #0
// CHECK: ^bb0 = dead
// CHECK: region #1
// CHECK: ^bb0 = live
// CHECK: region_preds: (all) predecessors:
// CHECK: scf.if
// CHECK: op_preds: (all) predecessors:
// CHECK: scf.yield {else}
scf.if %false {
scf.yield {then}
} else {
scf.yield {else}
} {tag = "test_known_if"}
return
}
// CHECK: callable:
// CHECK: region #0
// CHECK: ^bb0 = live
// CHECK: op_preds: predecessors:
// CHECK: func.call @callable() {then}
func.func @callable() attributes {tag = "callable"} {
return
}
func.func @test_dead_callsite() -> () {
%true = arith.constant true
scf.if %true {
func.call @callable() {then} : () -> ()
scf.yield
} else {
func.call @callable() {else} : () -> ()
scf.yield
}
return
}
func.func private @test_dead_return(%arg0: i32) -> i32 {
%true = arith.constant true
cf.cond_br %true, ^bb1, ^bb1
^bb1:
return {true} %arg0 : i32
^bb2:
return {false} %arg0 : i32
}
func.func @test_call_dead_return(%arg0: i32) -> () {
// CHECK: test_dead_return:
// CHECK: op_preds: (all) predecessors:
// CHECK: func.return {true}
%0 = func.call @test_dead_return(%arg0) {tag = "test_dead_return"} : (i32) -> i32
return
}

View File

@ -11,6 +11,7 @@ add_mlir_library(MLIRTestAnalysis
TestMemRefStrideCalculation.cpp
TestSlice.cpp
DataFlow/TestDeadCodeAnalysis.cpp
EXCLUDE_FROM_LIBMLIR

View File

@ -0,0 +1,118 @@
//===- TestDeadCodeAnalysis.cpp - Test dead code analysis -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
using namespace mlir::dataflow;
/// Print the liveness of every block, control-flow edge, and the predecessors
/// of all regions, callables, and calls.
static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
raw_ostream &os) {
op->walk([&](Operation *op) {
auto tag = op->getAttrOfType<StringAttr>("tag");
if (!tag)
return;
os << tag.getValue() << ":\n";
for (Region &region : op->getRegions()) {
os << " region #" << region.getRegionNumber() << "\n";
for (Block &block : region) {
os << " ";
block.printAsOperand(os);
os << " = ";
auto *live = solver.lookupState<Executable>(&block);
if (live)
os << *live;
else
os << "dead";
os << "\n";
for (Block *pred : block.getPredecessors()) {
os << " from ";
pred->printAsOperand(os);
os << " = ";
auto *live = solver.lookupState<Executable>(
solver.getProgramPoint<CFGEdge>(pred, &block));
if (live)
os << *live;
else
os << "dead";
os << "\n";
}
}
if (!region.empty()) {
auto *preds = solver.lookupState<PredecessorState>(&region.front());
if (preds)
os << "region_preds: " << *preds << "\n";
}
}
auto *preds = solver.lookupState<PredecessorState>(op);
if (preds)
os << "op_preds: " << *preds << "\n";
});
}
namespace {
/// This is a simple analysis that implements a transfer function for constant
/// operations.
struct ConstantAnalysis : public DataFlowAnalysis {
using DataFlowAnalysis::DataFlowAnalysis;
LogicalResult initialize(Operation *top) override {
WalkResult result = top->walk([&](Operation *op) {
if (op->hasTrait<OpTrait::ConstantLike>())
if (failed(visit(op)))
return WalkResult::interrupt();
return WalkResult::advance();
});
return success(!result.wasInterrupted());
}
LogicalResult visit(ProgramPoint point) override {
Operation *op = point.get<Operation *>();
Attribute value;
if (matchPattern(op, m_Constant(&value))) {
auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0));
propagateIfChanged(
constant, constant->join(ConstantValue(value, op->getDialect())));
}
return success();
}
};
/// This is a simple pass that runs dead code analysis with no constant value
/// provider. It marks everything as live.
struct TestDeadCodeAnalysisPass
: public PassWrapper<TestDeadCodeAnalysisPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDeadCodeAnalysisPass)
StringRef getArgument() const override { return "test-dead-code-analysis"; }
void runOnOperation() override {
Operation *op = getOperation();
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
solver.load<ConstantAnalysis>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
printAnalysisResults(solver, op, llvm::errs());
}
};
} // end anonymous namespace
namespace mlir {
namespace test {
void registerTestDeadCodeAnalysisPass() {
PassRegistration<TestDeadCodeAnalysisPass>();
}
} // end namespace test
} // end namespace mlir

View File

@ -72,6 +72,7 @@ void registerTestGpuSerializeToCubinPass();
void registerTestGpuSerializeToHsacoPass();
void registerTestDataFlowPass();
void registerTestDataLayoutQuery();
void registerTestDeadCodeAnalysisPass();
void registerTestDecomposeCallGraphTypes();
void registerTestDiagnosticsPass();
void registerTestDominancePass();
@ -173,6 +174,7 @@ void registerTestPasses() {
mlir::test::registerTestDecomposeCallGraphTypes();
mlir::test::registerTestDataFlowPass();
mlir::test::registerTestDataLayoutQuery();
mlir::test::registerTestDeadCodeAnalysisPass();
mlir::test::registerTestDominancePass();
mlir::test::registerTestDynamicPipelinePass();
mlir::test::registerTestExpandMathPass();