2020-04-21 17:54:05 +08:00
|
|
|
//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// This transformation pass performs a sparse conditional constant propagation
|
|
|
|
// in MLIR. It identifies values known to be constant, propagates that
|
|
|
|
// information throughout the IR, and replaces their uses with constants. This is done with an
|
2020-10-29 03:03:15 +08:00
|
|
|
// optimistic dataflow analysis that assumes that all values are constant until
|
2020-04-21 17:54:05 +08:00
|
|
|
// proven otherwise.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "PassDetail.h"
|
2021-04-27 10:33:50 +08:00
|
|
|
#include "mlir/Analysis/DataFlowAnalysis.h"
|
2020-04-21 17:54:05 +08:00
|
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
#include "mlir/IR/Dialect.h"
|
|
|
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
2020-05-14 01:27:19 +08:00
|
|
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
2020-04-21 17:54:05 +08:00
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "mlir/Transforms/FoldUtils.h"
|
|
|
|
#include "mlir/Transforms/Passes.h"
|
2022-01-05 16:44:27 +08:00
|
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
|
|
|
|
#define DEBUG_TYPE "sccp"
|
2020-04-21 17:54:05 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
2021-04-27 10:33:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// SCCP Analysis
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-04-21 17:54:05 +08:00
|
|
|
namespace {
|
2021-04-27 10:33:50 +08:00
|
|
|
struct SCCPLatticeValue {
|
|
|
|
SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr)
|
|
|
|
: constant(constant), constantDialect(dialect) {}
|
2020-04-21 17:54:05 +08:00
|
|
|
|
2021-04-27 10:33:50 +08:00
|
|
|
/// The pessimistic state of SCCP is non-constant.
|
|
|
|
static SCCPLatticeValue getPessimisticValueState(MLIRContext *context) {
|
|
|
|
return SCCPLatticeValue();
|
2020-04-21 17:54:05 +08:00
|
|
|
}
|
2021-04-27 10:33:50 +08:00
|
|
|
static SCCPLatticeValue getPessimisticValueState(Value value) {
|
|
|
|
return SCCPLatticeValue();
|
2020-04-21 17:54:05 +08:00
|
|
|
}
|
|
|
|
|
2021-04-27 10:33:50 +08:00
|
|
|
/// Equivalence for SCCP only accounts for the constant, not the originating
|
|
|
|
/// dialect.
|
|
|
|
bool operator==(const SCCPLatticeValue &rhs) const {
|
|
|
|
return constant == rhs.constant;
|
2020-04-21 17:54:05 +08:00
|
|
|
}
|
|
|
|
|
2021-04-27 10:33:50 +08:00
|
|
|
/// To join the state of two values, we simply check for equivalence.
|
|
|
|
static SCCPLatticeValue join(const SCCPLatticeValue &lhs,
|
|
|
|
const SCCPLatticeValue &rhs) {
|
|
|
|
return lhs == rhs ? lhs : SCCPLatticeValue();
|
2020-04-21 17:54:05 +08:00
|
|
|
}
|
|
|
|
|
2021-04-27 10:33:50 +08:00
|
|
|
/// The constant attribute value.
|
|
|
|
Attribute constant;
|
2020-04-21 17:54:05 +08:00
|
|
|
|
2021-04-27 10:33:50 +08:00
|
|
|
/// The dialect the constant originated from. This is not used as part of the
|
|
|
|
/// key, and is only needed to materialize the held constant if necessary.
|
2020-04-21 17:54:05 +08:00
|
|
|
Dialect *constantDialect;
|
|
|
|
};
|
|
|
|
|
2021-04-27 10:33:50 +08:00
|
|
|
/// Forward dataflow analysis that simulates execution by folding operations
/// over the constant operands known so far, and prunes control-flow
/// successors that constant operands prove unreachable.
struct SCCPAnalysis : public ForwardDataFlowAnalysis<SCCPLatticeValue> {
  using ForwardDataFlowAnalysis<SCCPLatticeValue>::ForwardDataFlowAnalysis;
  ~SCCPAnalysis() override = default;

  /// Simulate `op` by folding it with the currently known constant operands
  /// and merge the outcome into the lattice of each result. Returns whether
  /// any result lattice changed.
  ChangeResult
  visitOperation(Operation *op,
                 ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) final {
    LLVM_DEBUG(llvm::dbgs() << "SCCP: Visiting operation: " << *op << "\n");

    // Ops with regions are not simulated: we cannot guarantee their fold is
    // out-of-place, and in-place folds are disallowed because the goal is
    // simulated execution rather than general folding.
    if (op->getNumRegions())
      return markAllPessimisticFixpoint(op->getResults());

    SmallVector<Attribute> knownConstants = collectConstantOperands(operands);

    // Snapshot the operands and attributes before folding. The constants fed
    // to the folder may not match the real runtime values, so any in-place
    // update made by the fold must be undone.
    SmallVector<Value, 8> savedOperands(op->getOperands());
    DictionaryAttr savedAttrs = op->getAttrDictionary();

    SmallVector<OpFoldResult, 8> foldResults;
    foldResults.reserve(op->getNumResults());
    if (failed(op->fold(knownConstants, foldResults)))
      return markAllPessimisticFixpoint(op->getResults());

    // A successful fold that produced no results was an in-place fold: restore
    // the op and treat its results as overdefined.
    if (foldResults.empty()) {
      op->setOperands(savedOperands);
      op->setAttrs(savedAttrs);
      return markAllPessimisticFixpoint(op->getResults());
    }

    assert(foldResults.size() == op->getNumResults() && "invalid result size");
    Dialect *dialect = op->getDialect();
    ChangeResult changed = ChangeResult::NoChange;
    for (auto indexed : llvm::enumerate(foldResults)) {
      LatticeElement<SCCPLatticeValue> &resultLattice =
          getLatticeElement(op->getResult(indexed.index()));

      // A fold yields either a constant attribute or an existing value whose
      // lattice state is forwarded.
      OpFoldResult folded = indexed.value();
      if (Attribute attr = folded.dyn_cast<Attribute>()) {
        changed |= resultLattice.join(SCCPLatticeValue(attr, dialect));
        continue;
      }
      changed |= resultLattice.join(getLatticeElement(folded.get<Value>()));
    }
    return changed;
  }

  /// `getSuccessorsForOperands` for branch terminators: when the constant
  /// operands select a single successor, report only that one. Returns failure
  /// when no single successor can be determined.
  LogicalResult getSuccessorsForOperands(
      BranchOpInterface branch,
      ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
      SmallVectorImpl<Block *> &successors) final {
    Block *onlySuccessor =
        branch.getSuccessorForOperands(collectConstantOperands(operands));
    if (!onlySuccessor)
      return failure();
    successors.push_back(onlySuccessor);
    return success();
  }

  /// `getSuccessorsForOperands` for region-branching ops: lets the op use the
  /// constant operands to narrow down its successor regions.
  void getSuccessorsForOperands(
      RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
      ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
      SmallVectorImpl<RegionSuccessor> &successors) final {
    branch.getSuccessorRegions(sourceIndex, collectConstantOperands(operands),
                               successors);
  }

private:
  /// Gather the constant attribute held by each operand lattice, in operand
  /// order; a null attribute marks an operand not known to be constant.
  static SmallVector<Attribute> collectConstantOperands(
      ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) {
    SmallVector<Attribute> constants;
    constants.reserve(operands.size());
    for (LatticeElement<SCCPLatticeValue> *operand : operands)
      constants.push_back(operand->getValue().constant);
    return constants;
  }
};
|
2021-04-27 10:33:50 +08:00
|
|
|
} // namespace
|
2020-04-21 17:54:05 +08:00
|
|
|
|
2021-04-27 10:33:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// SCCP Rewrites
|
|
|
|
//===----------------------------------------------------------------------===//
|
2020-04-21 17:54:05 +08:00
|
|
|
|
2021-04-27 10:33:50 +08:00
|
|
|
/// Replace the given value with a constant if the corresponding lattice
|
|
|
|
/// represents a constant. Returns success if the value was replaced, failure
|
|
|
|
/// otherwise.
|
|
|
|
static LogicalResult replaceWithConstant(SCCPAnalysis &analysis,
|
|
|
|
OpBuilder &builder,
|
|
|
|
OperationFolder &folder, Value value) {
|
|
|
|
LatticeElement<SCCPLatticeValue> *lattice =
|
|
|
|
analysis.lookupLatticeElement(value);
|
|
|
|
if (!lattice)
|
|
|
|
return failure();
|
|
|
|
SCCPLatticeValue &latticeValue = lattice->getValue();
|
|
|
|
if (!latticeValue.constant)
|
|
|
|
return failure();
|
2020-04-21 17:54:05 +08:00
|
|
|
|
2021-04-27 10:33:50 +08:00
|
|
|
// Attempt to materialize a constant for the given value.
|
|
|
|
Dialect *dialect = latticeValue.constantDialect;
|
|
|
|
Value constant = folder.getOrCreateConstant(
|
|
|
|
builder, dialect, latticeValue.constant, value.getType(), value.getLoc());
|
|
|
|
if (!constant)
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
value.replaceAllUsesWith(constant);
|
|
|
|
return success();
|
2020-04-21 17:54:05 +08:00
|
|
|
}
|
|
|
|
|
2021-04-27 10:33:50 +08:00
|
|
|
/// Rewrite the given regions using the computing analysis. This replaces the
|
|
|
|
/// uses of all values that have been computed to be constant, and erases as
|
|
|
|
/// many newly dead operations.
|
|
|
|
static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
|
|
|
|
MutableArrayRef<Region> initialRegions) {
|
|
|
|
SmallVector<Block *> worklist;
|
2020-04-21 17:54:05 +08:00
|
|
|
auto addToWorklist = [&](MutableArrayRef<Region> regions) {
|
|
|
|
for (Region ®ion : regions)
|
2021-04-27 10:33:50 +08:00
|
|
|
for (Block &block : llvm::reverse(region))
|
|
|
|
worklist.push_back(&block);
|
2020-04-21 17:54:05 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
// An operation folder used to create and unique constants.
|
|
|
|
OperationFolder folder(context);
|
|
|
|
OpBuilder builder(context);
|
|
|
|
|
|
|
|
addToWorklist(initialRegions);
|
|
|
|
while (!worklist.empty()) {
|
|
|
|
Block *block = worklist.pop_back_val();
|
|
|
|
|
|
|
|
for (Operation &op : llvm::make_early_inc_range(*block)) {
|
|
|
|
builder.setInsertionPoint(&op);
|
|
|
|
|
|
|
|
// Replace any result with constants.
|
|
|
|
bool replacedAll = op.getNumResults() != 0;
|
|
|
|
for (Value res : op.getResults())
|
2021-04-27 10:33:50 +08:00
|
|
|
replacedAll &=
|
|
|
|
succeeded(replaceWithConstant(analysis, builder, folder, res));
|
2020-04-21 17:54:05 +08:00
|
|
|
|
|
|
|
// If all of the results of the operation were replaced, try to erase
|
|
|
|
// the operation completely.
|
|
|
|
if (replacedAll && wouldOpBeTriviallyDead(&op)) {
|
|
|
|
assert(op.use_empty() && "expected all uses to be replaced");
|
|
|
|
op.erase();
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Add any the regions of this operation to the worklist.
|
|
|
|
addToWorklist(op.getRegions());
|
|
|
|
}
|
|
|
|
|
2021-04-27 10:33:50 +08:00
|
|
|
// Replace any block arguments with constants.
|
|
|
|
builder.setInsertionPointToStart(block);
|
|
|
|
for (BlockArgument arg : block->getArguments())
|
|
|
|
(void)replaceWithConstant(analysis, builder, folder, arg);
|
2020-04-21 17:54:05 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// SCCP Pass
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
/// Sparse conditional constant propagation pass: analyzes the current
/// operation's regions and rewrites the values proven constant.
struct SCCP : SCCPBase<SCCP> {
  void runOnOperation() override;
};
|
2021-12-08 02:27:58 +08:00
|
|
|
} // namespace
|
2020-04-21 17:54:05 +08:00
|
|
|
|
|
|
|
void SCCP::runOnOperation() {
|
|
|
|
Operation *op = getOperation();
|
|
|
|
|
2021-04-27 10:33:50 +08:00
|
|
|
SCCPAnalysis analysis(op->getContext());
|
|
|
|
analysis.run(op);
|
|
|
|
rewrite(analysis, op->getContext(), op->getRegions());
|
2020-04-21 17:54:05 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
std::unique_ptr<Pass> mlir::createSCCPPass() {
|
|
|
|
return std::make_unique<SCCP>();
|
|
|
|
}
|