NFC: Rename FoldHelper to OperationFolder and split a large function in two.

PiperOrigin-RevId: 251485843
This commit is contained in:
River Riddle 2019-06-04 11:56:43 -07:00 committed by Mehdi Amini
parent 9fc4193eea
commit 9b4a02c1e9
4 changed files with 64 additions and 38 deletions

View File

@ -31,14 +31,15 @@
namespace mlir { namespace mlir {
class Function; class Function;
class Operation; class Operation;
class Value;
/// A helper class for folding operations, and unifying duplicated constants /// A utility class for folding operations, and unifying duplicated constants
/// generated along the way. /// generated along the way.
/// ///
/// To make sure constants properly dominate all their uses, constants are /// To make sure constants properly dominate all their uses, constants are
/// moved to the beginning of the entry block of the function when tracked by /// moved to the beginning of the entry block of the function when tracked by
/// this class. /// this class.
class FoldHelper { class OperationFolder {
public: public:
/// Constructs an instance for managing constants in the given function `f`. /// Constructs an instance for managing constants in the given function `f`.
/// Constants tracked by this instance will be moved to the entry block of /// Constants tracked by this instance will be moved to the entry block of
@ -47,7 +48,7 @@ public:
/// This instance does not proactively walk the operations inside `f`; /// This instance does not proactively walk the operations inside `f`;
/// instead, users must invoke the following methods to manually handle each /// instead, users must invoke the following methods to manually handle each
/// operation of interest. /// operation of interest.
FoldHelper(Function *f); OperationFolder(Function *f) : function(f) {}
/// Tries to perform folding on the given `op`, including unifying /// Tries to perform folding on the given `op`, including unifying
/// deduplicated constants. If successful, calls `preReplaceAction` (if /// deduplicated constants. If successful, calls `preReplaceAction` (if
@ -59,13 +60,17 @@ public:
std::function<void(Operation *)> preReplaceAction = {}); std::function<void(Operation *)> preReplaceAction = {});
/// Notifies that the given constant `op` should be remove from this /// Notifies that the given constant `op` should be remove from this
/// FoldHelper's internal bookkeeping. /// OperationFolder's internal bookkeeping.
/// ///
/// Note: this method must be called if a constant op is to be deleted /// Note: this method must be called if a constant op is to be deleted
/// externally to this FoldHelper. `op` must be a constant op. /// externally to this OperationFolder. `op` must be a constant op.
void notifyRemoval(Operation *op); void notifyRemoval(Operation *op);
private: private:
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the foldin.
LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value *> &results);
/// Tries to deduplicate the given constant and returns success if that can be /// Tries to deduplicate the given constant and returns success if that can be
/// done. This moves the given constant to the top of the entry block if it /// done. This moves the given constant to the top of the entry block if it
/// is first seen. If there is already an existing constant that is the same, /// is first seen. If there is already an existing constant that is the same,

View File

@ -32,12 +32,12 @@ struct TestConstantFold : public FunctionPass<TestConstantFold> {
// All constants in the function post folding. // All constants in the function post folding.
SmallVector<Operation *, 8> existingConstants; SmallVector<Operation *, 8> existingConstants;
void foldOperation(Operation *op, FoldHelper &helper); void foldOperation(Operation *op, OperationFolder &helper);
void runOnFunction() override; void runOnFunction() override;
}; };
} // end anonymous namespace } // end anonymous namespace
void TestConstantFold::foldOperation(Operation *op, FoldHelper &helper) { void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) {
// Attempt to fold the specified operation, including handling unused or // Attempt to fold the specified operation, including handling unused or
// duplicated constants. // duplicated constants.
if (succeeded(helper.tryToFold(op))) if (succeeded(helper.tryToFold(op)))
@ -56,7 +56,7 @@ void TestConstantFold::runOnFunction() {
existingConstants.clear(); existingConstants.clear();
auto &f = getFunction(); auto &f = getFunction();
FoldHelper helper(&f); OperationFolder helper(&f);
// Collect and fold the operations within the function. // Collect and fold the operations within the function.
SmallVector<Operation *, 8> ops; SmallVector<Operation *, 8> ops;

View File

@ -29,11 +29,13 @@
using namespace mlir; using namespace mlir;
FoldHelper::FoldHelper(Function *f) : function(f) {} //===----------------------------------------------------------------------===//
// OperationFolder
//===----------------------------------------------------------------------===//
LogicalResult LogicalResult
FoldHelper::tryToFold(Operation *op, OperationFolder::tryToFold(Operation *op,
std::function<void(Operation *)> preReplaceAction) { std::function<void(Operation *)> preReplaceAction) {
assert(op->getFunction() == function && assert(op->getFunction() == function &&
"cannot constant fold op from another function"); "cannot constant fold op from another function");
@ -52,8 +54,37 @@ FoldHelper::tryToFold(Operation *op,
return tryToUnify(op); return tryToUnify(op);
} }
// Try to fold the operation.
SmallVector<Value *, 8> results;
if (failed(tryToFold(op, results)))
return failure();
// Constant folding succeeded. We will start replacing this op's uses and
// eventually erase this op. Invoke the callback provided by the caller to
// perform any pre-replacement action.
if (preReplaceAction)
preReplaceAction(op);
// Check to see if the operation was just updated in place.
if (results.empty())
return success();
// Otherwise, replace all of the result values and erase the operation.
for (unsigned i = 0, e = results.size(); i != e; ++i)
op->getResult(i)->replaceAllUsesWith(results[i]);
op->erase();
return success();
}
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the foldin.
LogicalResult OperationFolder::tryToFold(Operation *op,
SmallVectorImpl<Value *> &results) {
assert(op->getFunction() == function &&
"cannot constant fold op from another function");
SmallVector<Attribute, 8> operandConstants; SmallVector<Attribute, 8> operandConstants;
SmallVector<OpFoldResult, 8> results; SmallVector<OpFoldResult, 8> foldResults;
// Check to see if any operands to the operation is constant and whether // Check to see if any operands to the operation is constant and whether
// the operation knows how to constant fold itself. // the operation knows how to constant fold itself.
@ -70,38 +101,29 @@ FoldHelper::tryToFold(Operation *op,
} }
// Attempt to constant fold the operation. // Attempt to constant fold the operation.
if (failed(op->fold(operandConstants, results))) if (failed(op->fold(operandConstants, foldResults)))
return failure(); return failure();
// Constant folding succeeded. We will start replacing this op's uses and
// eventually erase this op. Invoke the callback provided by the caller to
// perform any pre-replacement action.
if (preReplaceAction)
preReplaceAction(op);
// Check to see if the operation was just updated in place. // Check to see if the operation was just updated in place.
if (results.empty()) if (foldResults.empty())
return success(); return success();
assert(results.size() == op->getNumResults()); assert(foldResults.size() == op->getNumResults());
// Create the result constants and replace the results. // Create the result constants and replace the results.
FuncBuilder builder(op); FuncBuilder builder(op);
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
auto *res = op->getResult(i); assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
if (res->use_empty()) // Ignore dead uses.
continue;
assert(!results[i].isNull() && "expected valid OpFoldResult");
// Check if the result was an SSA value. // Check if the result was an SSA value.
if (auto *repl = results[i].dyn_cast<Value *>()) { if (auto *repl = foldResults[i].dyn_cast<Value *>()) {
if (repl != res) results.emplace_back(repl);
res->replaceAllUsesWith(repl);
continue; continue;
} }
// If we already have a canonicalized version of this constant, just reuse // If we already have a canonicalized version of this constant, just reuse
// it. Otherwise create a new one. // it. Otherwise create a new one.
Attribute attrRepl = results[i].get<Attribute>(); Attribute attrRepl = foldResults[i].get<Attribute>();
auto *res = op->getResult(i);
auto &constInst = auto &constInst =
uniquedConstants[std::make_pair(attrRepl, res->getType())]; uniquedConstants[std::make_pair(attrRepl, res->getType())];
if (!constInst) { if (!constInst) {
@ -113,14 +135,13 @@ FoldHelper::tryToFold(Operation *op,
constInst = newOp.getOperation(); constInst = newOp.getOperation();
moveConstantToEntryBlock(constInst); moveConstantToEntryBlock(constInst);
} }
res->replaceAllUsesWith(constInst->getResult(0)); results.push_back(constInst->getResult(0));
} }
op->erase();
return success(); return success();
} }
void FoldHelper::notifyRemoval(Operation *op) { void OperationFolder::notifyRemoval(Operation *op) {
assert(op->getFunction() == function && assert(op->getFunction() == function &&
"cannot remove constant from another function"); "cannot remove constant from another function");
@ -134,7 +155,7 @@ void FoldHelper::notifyRemoval(Operation *op) {
uniquedConstants.erase(it); uniquedConstants.erase(it);
} }
LogicalResult FoldHelper::tryToUnify(Operation *op) { LogicalResult OperationFolder::tryToUnify(Operation *op) {
Attribute constValue; Attribute constValue;
matchPattern(op, m_Constant(&constValue)); matchPattern(op, m_Constant(&constValue));
assert(constValue); assert(constValue);
@ -163,7 +184,7 @@ LogicalResult FoldHelper::tryToUnify(Operation *op) {
return failure(); return failure();
} }
void FoldHelper::moveConstantToEntryBlock(Operation *op) { void OperationFolder::moveConstantToEntryBlock(Operation *op) {
// Insert at the very top of the entry block. // Insert at the very top of the entry block.
auto &entryBB = function->front(); auto &entryBB = function->front();
op->moveBefore(&entryBB, entryBB.begin()); op->moveBefore(&entryBB, entryBB.begin());

View File

@ -143,7 +143,7 @@ private:
/// Perform the rewrites. /// Perform the rewrites.
bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) { bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
Function *fn = getFunction(); Function *fn = getFunction();
FoldHelper helper(fn); OperationFolder helper(fn);
bool changed = false; bool changed = false;
int i = 0; int i = 0;
@ -166,8 +166,8 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
// If the operation has no side effects, and no users, then it is // If the operation has no side effects, and no users, then it is
// trivially dead - remove it. // trivially dead - remove it.
if (op->hasNoSideEffect() && op->use_empty()) { if (op->hasNoSideEffect() && op->use_empty()) {
// Be careful to update bookkeeping in FoldHelper to keep consistency if // Be careful to update bookkeeping in OperationFolder to keep
// this is a constant op. // consistency if this is a constant op.
helper.notifyRemoval(op); helper.notifyRemoval(op);
op->erase(); op->erase();
continue; continue;