forked from OSchip/llvm-project
NFC: Rename FoldHelper to OperationFolder and split a large function in two.
PiperOrigin-RevId: 251485843
This commit is contained in:
parent
9fc4193eea
commit
9b4a02c1e9
|
@ -31,14 +31,15 @@
|
|||
namespace mlir {
|
||||
class Function;
|
||||
class Operation;
|
||||
class Value;
|
||||
|
||||
/// A helper class for folding operations, and unifying duplicated constants
|
||||
/// A utility class for folding operations, and unifying duplicated constants
|
||||
/// generated along the way.
|
||||
///
|
||||
/// To make sure constants properly dominate all their uses, constants are
|
||||
/// moved to the beginning of the entry block of the function when tracked by
|
||||
/// this class.
|
||||
class FoldHelper {
|
||||
class OperationFolder {
|
||||
public:
|
||||
/// Constructs an instance for managing constants in the given function `f`.
|
||||
/// Constants tracked by this instance will be moved to the entry block of
|
||||
|
@ -47,7 +48,7 @@ public:
|
|||
/// This instance does not proactively walk the operations inside `f`;
|
||||
/// instead, users must invoke the following methods to manually handle each
|
||||
/// operation of interest.
|
||||
FoldHelper(Function *f);
|
||||
OperationFolder(Function *f) : function(f) {}
|
||||
|
||||
/// Tries to perform folding on the given `op`, including unifying
|
||||
/// deduplicated constants. If successful, calls `preReplaceAction` (if
|
||||
|
@ -59,13 +60,17 @@ public:
|
|||
std::function<void(Operation *)> preReplaceAction = {});
|
||||
|
||||
/// Notifies that the given constant `op` should be remove from this
|
||||
/// FoldHelper's internal bookkeeping.
|
||||
/// OperationFolder's internal bookkeeping.
|
||||
///
|
||||
/// Note: this method must be called if a constant op is to be deleted
|
||||
/// externally to this FoldHelper. `op` must be a constant op.
|
||||
/// externally to this OperationFolder. `op` must be a constant op.
|
||||
void notifyRemoval(Operation *op);
|
||||
|
||||
private:
|
||||
/// Tries to perform folding on the given `op`. If successful, populates
|
||||
/// `results` with the results of the foldin.
|
||||
LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value *> &results);
|
||||
|
||||
/// Tries to deduplicate the given constant and returns success if that can be
|
||||
/// done. This moves the given constant to the top of the entry block if it
|
||||
/// is first seen. If there is already an existing constant that is the same,
|
||||
|
|
|
@ -32,12 +32,12 @@ struct TestConstantFold : public FunctionPass<TestConstantFold> {
|
|||
// All constants in the function post folding.
|
||||
SmallVector<Operation *, 8> existingConstants;
|
||||
|
||||
void foldOperation(Operation *op, FoldHelper &helper);
|
||||
void foldOperation(Operation *op, OperationFolder &helper);
|
||||
void runOnFunction() override;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void TestConstantFold::foldOperation(Operation *op, FoldHelper &helper) {
|
||||
void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) {
|
||||
// Attempt to fold the specified operation, including handling unused or
|
||||
// duplicated constants.
|
||||
if (succeeded(helper.tryToFold(op)))
|
||||
|
@ -56,7 +56,7 @@ void TestConstantFold::runOnFunction() {
|
|||
existingConstants.clear();
|
||||
|
||||
auto &f = getFunction();
|
||||
FoldHelper helper(&f);
|
||||
OperationFolder helper(&f);
|
||||
|
||||
// Collect and fold the operations within the function.
|
||||
SmallVector<Operation *, 8> ops;
|
||||
|
|
|
@ -29,10 +29,12 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
FoldHelper::FoldHelper(Function *f) : function(f) {}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OperationFolder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
FoldHelper::tryToFold(Operation *op,
|
||||
OperationFolder::tryToFold(Operation *op,
|
||||
std::function<void(Operation *)> preReplaceAction) {
|
||||
assert(op->getFunction() == function &&
|
||||
"cannot constant fold op from another function");
|
||||
|
@ -52,8 +54,37 @@ FoldHelper::tryToFold(Operation *op,
|
|||
return tryToUnify(op);
|
||||
}
|
||||
|
||||
// Try to fold the operation.
|
||||
SmallVector<Value *, 8> results;
|
||||
if (failed(tryToFold(op, results)))
|
||||
return failure();
|
||||
|
||||
// Constant folding succeeded. We will start replacing this op's uses and
|
||||
// eventually erase this op. Invoke the callback provided by the caller to
|
||||
// perform any pre-replacement action.
|
||||
if (preReplaceAction)
|
||||
preReplaceAction(op);
|
||||
|
||||
// Check to see if the operation was just updated in place.
|
||||
if (results.empty())
|
||||
return success();
|
||||
|
||||
// Otherwise, replace all of the result values and erase the operation.
|
||||
for (unsigned i = 0, e = results.size(); i != e; ++i)
|
||||
op->getResult(i)->replaceAllUsesWith(results[i]);
|
||||
op->erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Tries to perform folding on the given `op`. If successful, populates
|
||||
/// `results` with the results of the foldin.
|
||||
LogicalResult OperationFolder::tryToFold(Operation *op,
|
||||
SmallVectorImpl<Value *> &results) {
|
||||
assert(op->getFunction() == function &&
|
||||
"cannot constant fold op from another function");
|
||||
|
||||
SmallVector<Attribute, 8> operandConstants;
|
||||
SmallVector<OpFoldResult, 8> results;
|
||||
SmallVector<OpFoldResult, 8> foldResults;
|
||||
|
||||
// Check to see if any operands to the operation is constant and whether
|
||||
// the operation knows how to constant fold itself.
|
||||
|
@ -70,38 +101,29 @@ FoldHelper::tryToFold(Operation *op,
|
|||
}
|
||||
|
||||
// Attempt to constant fold the operation.
|
||||
if (failed(op->fold(operandConstants, results)))
|
||||
if (failed(op->fold(operandConstants, foldResults)))
|
||||
return failure();
|
||||
|
||||
// Constant folding succeeded. We will start replacing this op's uses and
|
||||
// eventually erase this op. Invoke the callback provided by the caller to
|
||||
// perform any pre-replacement action.
|
||||
if (preReplaceAction)
|
||||
preReplaceAction(op);
|
||||
|
||||
// Check to see if the operation was just updated in place.
|
||||
if (results.empty())
|
||||
if (foldResults.empty())
|
||||
return success();
|
||||
assert(results.size() == op->getNumResults());
|
||||
assert(foldResults.size() == op->getNumResults());
|
||||
|
||||
// Create the result constants and replace the results.
|
||||
FuncBuilder builder(op);
|
||||
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
|
||||
auto *res = op->getResult(i);
|
||||
if (res->use_empty()) // Ignore dead uses.
|
||||
continue;
|
||||
assert(!results[i].isNull() && "expected valid OpFoldResult");
|
||||
assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
|
||||
|
||||
// Check if the result was an SSA value.
|
||||
if (auto *repl = results[i].dyn_cast<Value *>()) {
|
||||
if (repl != res)
|
||||
res->replaceAllUsesWith(repl);
|
||||
if (auto *repl = foldResults[i].dyn_cast<Value *>()) {
|
||||
results.emplace_back(repl);
|
||||
continue;
|
||||
}
|
||||
|
||||
// If we already have a canonicalized version of this constant, just reuse
|
||||
// it. Otherwise create a new one.
|
||||
Attribute attrRepl = results[i].get<Attribute>();
|
||||
Attribute attrRepl = foldResults[i].get<Attribute>();
|
||||
auto *res = op->getResult(i);
|
||||
auto &constInst =
|
||||
uniquedConstants[std::make_pair(attrRepl, res->getType())];
|
||||
if (!constInst) {
|
||||
|
@ -113,14 +135,13 @@ FoldHelper::tryToFold(Operation *op,
|
|||
constInst = newOp.getOperation();
|
||||
moveConstantToEntryBlock(constInst);
|
||||
}
|
||||
res->replaceAllUsesWith(constInst->getResult(0));
|
||||
results.push_back(constInst->getResult(0));
|
||||
}
|
||||
op->erase();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void FoldHelper::notifyRemoval(Operation *op) {
|
||||
void OperationFolder::notifyRemoval(Operation *op) {
|
||||
assert(op->getFunction() == function &&
|
||||
"cannot remove constant from another function");
|
||||
|
||||
|
@ -134,7 +155,7 @@ void FoldHelper::notifyRemoval(Operation *op) {
|
|||
uniquedConstants.erase(it);
|
||||
}
|
||||
|
||||
LogicalResult FoldHelper::tryToUnify(Operation *op) {
|
||||
LogicalResult OperationFolder::tryToUnify(Operation *op) {
|
||||
Attribute constValue;
|
||||
matchPattern(op, m_Constant(&constValue));
|
||||
assert(constValue);
|
||||
|
@ -163,7 +184,7 @@ LogicalResult FoldHelper::tryToUnify(Operation *op) {
|
|||
return failure();
|
||||
}
|
||||
|
||||
void FoldHelper::moveConstantToEntryBlock(Operation *op) {
|
||||
void OperationFolder::moveConstantToEntryBlock(Operation *op) {
|
||||
// Insert at the very top of the entry block.
|
||||
auto &entryBB = function->front();
|
||||
op->moveBefore(&entryBB, entryBB.begin());
|
||||
|
|
|
@ -143,7 +143,7 @@ private:
|
|||
/// Perform the rewrites.
|
||||
bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
|
||||
Function *fn = getFunction();
|
||||
FoldHelper helper(fn);
|
||||
OperationFolder helper(fn);
|
||||
|
||||
bool changed = false;
|
||||
int i = 0;
|
||||
|
@ -166,8 +166,8 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
|
|||
// If the operation has no side effects, and no users, then it is
|
||||
// trivially dead - remove it.
|
||||
if (op->hasNoSideEffect() && op->use_empty()) {
|
||||
// Be careful to update bookkeeping in FoldHelper to keep consistency if
|
||||
// this is a constant op.
|
||||
// Be careful to update bookkeeping in OperationFolder to keep
|
||||
// consistency if this is a constant op.
|
||||
helper.notifyRemoval(op);
|
||||
op->erase();
|
||||
continue;
|
||||
|
|
Loading…
Reference in New Issue