forked from OSchip/llvm-project
NFC: Rename FoldHelper to OperationFolder and split a large function in two.
PiperOrigin-RevId: 251485843
This commit is contained in:
parent
9fc4193eea
commit
9b4a02c1e9
|
@ -31,14 +31,15 @@
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class Function;
|
class Function;
|
||||||
class Operation;
|
class Operation;
|
||||||
|
class Value;
|
||||||
|
|
||||||
/// A helper class for folding operations, and unifying duplicated constants
|
/// A utility class for folding operations, and unifying duplicated constants
|
||||||
/// generated along the way.
|
/// generated along the way.
|
||||||
///
|
///
|
||||||
/// To make sure constants properly dominate all their uses, constants are
|
/// To make sure constants properly dominate all their uses, constants are
|
||||||
/// moved to the beginning of the entry block of the function when tracked by
|
/// moved to the beginning of the entry block of the function when tracked by
|
||||||
/// this class.
|
/// this class.
|
||||||
class FoldHelper {
|
class OperationFolder {
|
||||||
public:
|
public:
|
||||||
/// Constructs an instance for managing constants in the given function `f`.
|
/// Constructs an instance for managing constants in the given function `f`.
|
||||||
/// Constants tracked by this instance will be moved to the entry block of
|
/// Constants tracked by this instance will be moved to the entry block of
|
||||||
|
@ -47,7 +48,7 @@ public:
|
||||||
/// This instance does not proactively walk the operations inside `f`;
|
/// This instance does not proactively walk the operations inside `f`;
|
||||||
/// instead, users must invoke the following methods to manually handle each
|
/// instead, users must invoke the following methods to manually handle each
|
||||||
/// operation of interest.
|
/// operation of interest.
|
||||||
FoldHelper(Function *f);
|
OperationFolder(Function *f) : function(f) {}
|
||||||
|
|
||||||
/// Tries to perform folding on the given `op`, including unifying
|
/// Tries to perform folding on the given `op`, including unifying
|
||||||
/// deduplicated constants. If successful, calls `preReplaceAction` (if
|
/// deduplicated constants. If successful, calls `preReplaceAction` (if
|
||||||
|
@ -59,13 +60,17 @@ public:
|
||||||
std::function<void(Operation *)> preReplaceAction = {});
|
std::function<void(Operation *)> preReplaceAction = {});
|
||||||
|
|
||||||
/// Notifies that the given constant `op` should be remove from this
|
/// Notifies that the given constant `op` should be remove from this
|
||||||
/// FoldHelper's internal bookkeeping.
|
/// OperationFolder's internal bookkeeping.
|
||||||
///
|
///
|
||||||
/// Note: this method must be called if a constant op is to be deleted
|
/// Note: this method must be called if a constant op is to be deleted
|
||||||
/// externally to this FoldHelper. `op` must be a constant op.
|
/// externally to this OperationFolder. `op` must be a constant op.
|
||||||
void notifyRemoval(Operation *op);
|
void notifyRemoval(Operation *op);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
/// Tries to perform folding on the given `op`. If successful, populates
|
||||||
|
/// `results` with the results of the foldin.
|
||||||
|
LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value *> &results);
|
||||||
|
|
||||||
/// Tries to deduplicate the given constant and returns success if that can be
|
/// Tries to deduplicate the given constant and returns success if that can be
|
||||||
/// done. This moves the given constant to the top of the entry block if it
|
/// done. This moves the given constant to the top of the entry block if it
|
||||||
/// is first seen. If there is already an existing constant that is the same,
|
/// is first seen. If there is already an existing constant that is the same,
|
||||||
|
|
|
@ -32,12 +32,12 @@ struct TestConstantFold : public FunctionPass<TestConstantFold> {
|
||||||
// All constants in the function post folding.
|
// All constants in the function post folding.
|
||||||
SmallVector<Operation *, 8> existingConstants;
|
SmallVector<Operation *, 8> existingConstants;
|
||||||
|
|
||||||
void foldOperation(Operation *op, FoldHelper &helper);
|
void foldOperation(Operation *op, OperationFolder &helper);
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
void TestConstantFold::foldOperation(Operation *op, FoldHelper &helper) {
|
void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) {
|
||||||
// Attempt to fold the specified operation, including handling unused or
|
// Attempt to fold the specified operation, including handling unused or
|
||||||
// duplicated constants.
|
// duplicated constants.
|
||||||
if (succeeded(helper.tryToFold(op)))
|
if (succeeded(helper.tryToFold(op)))
|
||||||
|
@ -56,7 +56,7 @@ void TestConstantFold::runOnFunction() {
|
||||||
existingConstants.clear();
|
existingConstants.clear();
|
||||||
|
|
||||||
auto &f = getFunction();
|
auto &f = getFunction();
|
||||||
FoldHelper helper(&f);
|
OperationFolder helper(&f);
|
||||||
|
|
||||||
// Collect and fold the operations within the function.
|
// Collect and fold the operations within the function.
|
||||||
SmallVector<Operation *, 8> ops;
|
SmallVector<Operation *, 8> ops;
|
||||||
|
|
|
@ -29,11 +29,13 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
FoldHelper::FoldHelper(Function *f) : function(f) {}
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OperationFolder
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
FoldHelper::tryToFold(Operation *op,
|
OperationFolder::tryToFold(Operation *op,
|
||||||
std::function<void(Operation *)> preReplaceAction) {
|
std::function<void(Operation *)> preReplaceAction) {
|
||||||
assert(op->getFunction() == function &&
|
assert(op->getFunction() == function &&
|
||||||
"cannot constant fold op from another function");
|
"cannot constant fold op from another function");
|
||||||
|
|
||||||
|
@ -52,8 +54,37 @@ FoldHelper::tryToFold(Operation *op,
|
||||||
return tryToUnify(op);
|
return tryToUnify(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to fold the operation.
|
||||||
|
SmallVector<Value *, 8> results;
|
||||||
|
if (failed(tryToFold(op, results)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Constant folding succeeded. We will start replacing this op's uses and
|
||||||
|
// eventually erase this op. Invoke the callback provided by the caller to
|
||||||
|
// perform any pre-replacement action.
|
||||||
|
if (preReplaceAction)
|
||||||
|
preReplaceAction(op);
|
||||||
|
|
||||||
|
// Check to see if the operation was just updated in place.
|
||||||
|
if (results.empty())
|
||||||
|
return success();
|
||||||
|
|
||||||
|
// Otherwise, replace all of the result values and erase the operation.
|
||||||
|
for (unsigned i = 0, e = results.size(); i != e; ++i)
|
||||||
|
op->getResult(i)->replaceAllUsesWith(results[i]);
|
||||||
|
op->erase();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tries to perform folding on the given `op`. If successful, populates
|
||||||
|
/// `results` with the results of the foldin.
|
||||||
|
LogicalResult OperationFolder::tryToFold(Operation *op,
|
||||||
|
SmallVectorImpl<Value *> &results) {
|
||||||
|
assert(op->getFunction() == function &&
|
||||||
|
"cannot constant fold op from another function");
|
||||||
|
|
||||||
SmallVector<Attribute, 8> operandConstants;
|
SmallVector<Attribute, 8> operandConstants;
|
||||||
SmallVector<OpFoldResult, 8> results;
|
SmallVector<OpFoldResult, 8> foldResults;
|
||||||
|
|
||||||
// Check to see if any operands to the operation is constant and whether
|
// Check to see if any operands to the operation is constant and whether
|
||||||
// the operation knows how to constant fold itself.
|
// the operation knows how to constant fold itself.
|
||||||
|
@ -70,38 +101,29 @@ FoldHelper::tryToFold(Operation *op,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt to constant fold the operation.
|
// Attempt to constant fold the operation.
|
||||||
if (failed(op->fold(operandConstants, results)))
|
if (failed(op->fold(operandConstants, foldResults)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Constant folding succeeded. We will start replacing this op's uses and
|
|
||||||
// eventually erase this op. Invoke the callback provided by the caller to
|
|
||||||
// perform any pre-replacement action.
|
|
||||||
if (preReplaceAction)
|
|
||||||
preReplaceAction(op);
|
|
||||||
|
|
||||||
// Check to see if the operation was just updated in place.
|
// Check to see if the operation was just updated in place.
|
||||||
if (results.empty())
|
if (foldResults.empty())
|
||||||
return success();
|
return success();
|
||||||
assert(results.size() == op->getNumResults());
|
assert(foldResults.size() == op->getNumResults());
|
||||||
|
|
||||||
// Create the result constants and replace the results.
|
// Create the result constants and replace the results.
|
||||||
FuncBuilder builder(op);
|
FuncBuilder builder(op);
|
||||||
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
|
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
|
||||||
auto *res = op->getResult(i);
|
assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
|
||||||
if (res->use_empty()) // Ignore dead uses.
|
|
||||||
continue;
|
|
||||||
assert(!results[i].isNull() && "expected valid OpFoldResult");
|
|
||||||
|
|
||||||
// Check if the result was an SSA value.
|
// Check if the result was an SSA value.
|
||||||
if (auto *repl = results[i].dyn_cast<Value *>()) {
|
if (auto *repl = foldResults[i].dyn_cast<Value *>()) {
|
||||||
if (repl != res)
|
results.emplace_back(repl);
|
||||||
res->replaceAllUsesWith(repl);
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we already have a canonicalized version of this constant, just reuse
|
// If we already have a canonicalized version of this constant, just reuse
|
||||||
// it. Otherwise create a new one.
|
// it. Otherwise create a new one.
|
||||||
Attribute attrRepl = results[i].get<Attribute>();
|
Attribute attrRepl = foldResults[i].get<Attribute>();
|
||||||
|
auto *res = op->getResult(i);
|
||||||
auto &constInst =
|
auto &constInst =
|
||||||
uniquedConstants[std::make_pair(attrRepl, res->getType())];
|
uniquedConstants[std::make_pair(attrRepl, res->getType())];
|
||||||
if (!constInst) {
|
if (!constInst) {
|
||||||
|
@ -113,14 +135,13 @@ FoldHelper::tryToFold(Operation *op,
|
||||||
constInst = newOp.getOperation();
|
constInst = newOp.getOperation();
|
||||||
moveConstantToEntryBlock(constInst);
|
moveConstantToEntryBlock(constInst);
|
||||||
}
|
}
|
||||||
res->replaceAllUsesWith(constInst->getResult(0));
|
results.push_back(constInst->getResult(0));
|
||||||
}
|
}
|
||||||
op->erase();
|
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void FoldHelper::notifyRemoval(Operation *op) {
|
void OperationFolder::notifyRemoval(Operation *op) {
|
||||||
assert(op->getFunction() == function &&
|
assert(op->getFunction() == function &&
|
||||||
"cannot remove constant from another function");
|
"cannot remove constant from another function");
|
||||||
|
|
||||||
|
@ -134,7 +155,7 @@ void FoldHelper::notifyRemoval(Operation *op) {
|
||||||
uniquedConstants.erase(it);
|
uniquedConstants.erase(it);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult FoldHelper::tryToUnify(Operation *op) {
|
LogicalResult OperationFolder::tryToUnify(Operation *op) {
|
||||||
Attribute constValue;
|
Attribute constValue;
|
||||||
matchPattern(op, m_Constant(&constValue));
|
matchPattern(op, m_Constant(&constValue));
|
||||||
assert(constValue);
|
assert(constValue);
|
||||||
|
@ -163,7 +184,7 @@ LogicalResult FoldHelper::tryToUnify(Operation *op) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
void FoldHelper::moveConstantToEntryBlock(Operation *op) {
|
void OperationFolder::moveConstantToEntryBlock(Operation *op) {
|
||||||
// Insert at the very top of the entry block.
|
// Insert at the very top of the entry block.
|
||||||
auto &entryBB = function->front();
|
auto &entryBB = function->front();
|
||||||
op->moveBefore(&entryBB, entryBB.begin());
|
op->moveBefore(&entryBB, entryBB.begin());
|
||||||
|
|
|
@ -143,7 +143,7 @@ private:
|
||||||
/// Perform the rewrites.
|
/// Perform the rewrites.
|
||||||
bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
|
bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
|
||||||
Function *fn = getFunction();
|
Function *fn = getFunction();
|
||||||
FoldHelper helper(fn);
|
OperationFolder helper(fn);
|
||||||
|
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -166,8 +166,8 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
|
||||||
// If the operation has no side effects, and no users, then it is
|
// If the operation has no side effects, and no users, then it is
|
||||||
// trivially dead - remove it.
|
// trivially dead - remove it.
|
||||||
if (op->hasNoSideEffect() && op->use_empty()) {
|
if (op->hasNoSideEffect() && op->use_empty()) {
|
||||||
// Be careful to update bookkeeping in FoldHelper to keep consistency if
|
// Be careful to update bookkeeping in OperationFolder to keep
|
||||||
// this is a constant op.
|
// consistency if this is a constant op.
|
||||||
helper.notifyRemoval(op);
|
helper.notifyRemoval(op);
|
||||||
op->erase();
|
op->erase();
|
||||||
continue;
|
continue;
|
||||||
|
|
Loading…
Reference in New Issue